Skip to content

Commit 595e386

Browse files
authored
Merge branch 'devel' into debug-weightedavg
2 parents 258afbb + 47bbd65 commit 595e386

8 files changed

Lines changed: 193 additions & 4 deletions

File tree

.github/workflows/test_cc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
- run: python -m pip install uv
3030
- name: Install Python dependencies
3131
run: |
32-
source/install/uv_with_retry.sh pip install --system tensorflow-cpu
32+
source/install/uv_with_retry.sh pip install --system tensorflow-cpu~=2.18.0 jax==0.5.0
3333
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
3434
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py
3535
- name: Convert models

.github/workflows/test_cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
4848
if: false # skip as we use nvidia image
4949
- run: python -m pip install -U uv
50-
- run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.6.0" "jax[cuda12]"
50+
- run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.6.0" "jax[cuda12]==0.5.0"
5151
- run: |
5252
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
5353
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')

.github/workflows/test_python.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ jobs:
2525
python-version: ${{ matrix.python }}
2626
- run: python -m pip install -U uv
2727
- run: |
28-
source/install/uv_with_retry.sh pip install --system openmpi tensorflow-cpu
28+
source/install/uv_with_retry.sh pip install --system openmpi tensorflow-cpu~=2.18.0
2929
source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu
3030
export TENSORFLOW_ROOT=$(python -c 'import tensorflow;print(tensorflow.__path__[0])')
3131
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
32-
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py
32+
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py "jax==0.5.0;python_version>='3.10'"
3333
source/install/uv_with_retry.sh pip install --system horovod --no-build-isolation
3434
source/install/uv_with_retry.sh pip install --system --pre "paddlepaddle" -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
3535
env:

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
pair_exclude_types: list[tuple[int, int]] = [],
8080
rcond: Optional[float] = None,
8181
preset_out_bias: Optional[dict[str, np.ndarray]] = None,
82+
data_stat_protect: float = 1e-2,
8283
) -> None:
8384
torch.nn.Module.__init__(self)
8485
BaseAtomicModel_.__init__(self)
@@ -87,6 +88,7 @@ def __init__(
8788
self.reinit_pair_exclude(pair_exclude_types)
8889
self.rcond = rcond
8990
self.preset_out_bias = preset_out_bias
91+
self.data_stat_protect = data_stat_protect
9092

9193
def init_out_stat(self) -> None:
9294
"""Initialize the output bias."""

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,9 @@ def wrapped_sampler():
299299
return sampled
300300

301301
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
302+
self.fitting_net.compute_input_stats(
303+
wrapped_sampler, protection=self.data_stat_protect
304+
)
302305
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
303306

304307
def get_dim_fparam(self) -> int:

deepmd/pt/model/model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def get_standard_model(model_params):
254254
preset_out_bias = _convert_preset_out_bias_to_array(
255255
preset_out_bias, model_params["type_map"]
256256
)
257+
data_stat_protect = model_params.get("data_stat_protect", 1e-2)
257258

258259
if fitting_net_type == "dipole":
259260
modelcls = DipoleModel
@@ -275,6 +276,7 @@ def get_standard_model(model_params):
275276
atom_exclude_types=atom_exclude_types,
276277
pair_exclude_types=pair_exclude_types,
277278
preset_out_bias=preset_out_bias,
279+
data_stat_protect=data_stat_protect,
278280
)
279281
if model_params.get("hessian_mode"):
280282
model.enable_hessian()

deepmd/pt/model/task/fitting.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
abstractmethod,
55
)
66
from typing import (
7+
Callable,
78
Optional,
89
Union,
910
)
@@ -71,6 +72,84 @@ def share_params(self, base_class, shared_level, resume=False) -> None:
7172
else:
7273
raise NotImplementedError
7374

75+
def compute_input_stats(
76+
self,
77+
merged: Union[Callable[[], list[dict]], list[dict]],
78+
protection: float = 1e-2,
79+
) -> None:
80+
"""
81+
Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
82+
83+
Parameters
84+
----------
85+
merged : Union[Callable[[], list[dict]], list[dict]]
86+
- list[dict]: A list of data samples from various data systems.
87+
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
88+
originating from the `i`-th data system.
89+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
90+
only when needed. Since the sampling process can be slow and memory-intensive,
91+
the lazy function helps by only sampling once.
92+
protection : float
93+
Divided-by-zero protection
94+
"""
95+
if callable(merged):
96+
sampled = merged()
97+
else:
98+
sampled = merged
99+
# stat fparam
100+
if self.numb_fparam > 0:
101+
cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0)
102+
cat_data = torch.reshape(cat_data, [-1, self.numb_fparam])
103+
fparam_avg = torch.mean(cat_data, dim=0)
104+
fparam_std = torch.std(cat_data, dim=0, unbiased=False)
105+
fparam_std = torch.where(
106+
fparam_std < protection,
107+
torch.tensor(
108+
protection, dtype=fparam_std.dtype, device=fparam_std.device
109+
),
110+
fparam_std,
111+
)
112+
fparam_inv_std = 1.0 / fparam_std
113+
self.fparam_avg.copy_(
114+
torch.tensor(fparam_avg, device=env.DEVICE, dtype=self.fparam_avg.dtype)
115+
)
116+
self.fparam_inv_std.copy_(
117+
torch.tensor(
118+
fparam_inv_std, device=env.DEVICE, dtype=self.fparam_inv_std.dtype
119+
)
120+
)
121+
# stat aparam
122+
if self.numb_aparam > 0:
123+
sys_sumv = []
124+
sys_sumv2 = []
125+
sys_sumn = []
126+
for ss_ in [frame["aparam"] for frame in sampled]:
127+
ss = torch.reshape(ss_, [-1, self.numb_aparam])
128+
sys_sumv.append(torch.sum(ss, dim=0))
129+
sys_sumv2.append(torch.sum(ss * ss, dim=0))
130+
sys_sumn.append(ss.shape[0])
131+
sumv = torch.sum(torch.stack(sys_sumv), dim=0)
132+
sumv2 = torch.sum(torch.stack(sys_sumv2), dim=0)
133+
sumn = sum(sys_sumn)
134+
aparam_avg = sumv / sumn
135+
aparam_std = torch.sqrt(sumv2 / sumn - (sumv / sumn) ** 2)
136+
aparam_std = torch.where(
137+
aparam_std < protection,
138+
torch.tensor(
139+
protection, dtype=aparam_std.dtype, device=aparam_std.device
140+
),
141+
aparam_std,
142+
)
143+
aparam_inv_std = 1.0 / aparam_std
144+
self.aparam_avg.copy_(
145+
torch.tensor(aparam_avg, device=env.DEVICE, dtype=self.aparam_avg.dtype)
146+
)
147+
self.aparam_inv_std.copy_(
148+
torch.tensor(
149+
aparam_inv_std, device=env.DEVICE, dtype=self.aparam_inv_std.dtype
150+
)
151+
)
152+
74153

75154
class GeneralFitting(Fitting):
76155
"""Construct a general fitting net.
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import unittest
3+
4+
import numpy as np
5+
6+
from deepmd.pt.model.descriptor import (
7+
DescrptSeA,
8+
)
9+
from deepmd.pt.model.task import (
10+
EnergyFittingNet,
11+
)
12+
from deepmd.pt.utils.utils import (
13+
to_numpy_array,
14+
to_torch_tensor,
15+
)
16+
17+
18+
def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds):
19+
merged_output_stat = []
20+
nsys = len(sys_natoms)
21+
ndof = len(avgs)
22+
for ii in range(nsys):
23+
sys_dict = {}
24+
tmp_data_f = []
25+
tmp_data_a = []
26+
for jj in range(ndof):
27+
rng = np.random.default_rng(2025 * ii + 220 * jj)
28+
tmp_data_f.append(
29+
rng.normal(loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], 1))
30+
)
31+
rng = np.random.default_rng(220 * ii + 1636 * jj)
32+
tmp_data_a.append(
33+
rng.normal(
34+
loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], sys_natoms[ii])
35+
)
36+
)
37+
tmp_data_f = np.transpose(tmp_data_f, (1, 2, 0))
38+
tmp_data_a = np.transpose(tmp_data_a, (1, 2, 0))
39+
sys_dict["fparam"] = to_torch_tensor(tmp_data_f)
40+
sys_dict["aparam"] = to_torch_tensor(tmp_data_a)
41+
merged_output_stat.append(sys_dict)
42+
return merged_output_stat
43+
44+
45+
def _brute_fparam_pt(data, ndim):
46+
adata = [to_numpy_array(ii["fparam"]) for ii in data]
47+
all_data = []
48+
for ii in adata:
49+
tmp = np.reshape(ii, [-1, ndim])
50+
if len(all_data) == 0:
51+
all_data = np.array(tmp)
52+
else:
53+
all_data = np.concatenate((all_data, tmp), axis=0)
54+
avg = np.average(all_data, axis=0)
55+
std = np.std(all_data, axis=0)
56+
return avg, std
57+
58+
59+
def _brute_aparam_pt(data, ndim):
60+
adata = [to_numpy_array(ii["aparam"]) for ii in data]
61+
all_data = []
62+
for ii in adata:
63+
tmp = np.reshape(ii, [-1, ndim])
64+
if len(all_data) == 0:
65+
all_data = np.array(tmp)
66+
else:
67+
all_data = np.concatenate((all_data, tmp), axis=0)
68+
avg = np.average(all_data, axis=0)
69+
std = np.std(all_data, axis=0)
70+
return avg, std
71+
72+
73+
class TestEnerFittingStat(unittest.TestCase):
74+
def test(self) -> None:
75+
descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
76+
fitting = EnergyFittingNet(
77+
descrpt.get_ntypes(),
78+
descrpt.get_dim_out(),
79+
neuron=[240, 240, 240],
80+
resnet_dt=True,
81+
numb_fparam=3,
82+
numb_aparam=3,
83+
)
84+
avgs = [0, 10, 100]
85+
stds = [2, 0.4, 0.00001]
86+
sys_natoms = [10, 100]
87+
sys_nframes = [5, 2]
88+
all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds)
89+
frefa, frefs = _brute_fparam_pt(all_data, len(avgs))
90+
arefa, arefs = _brute_aparam_pt(all_data, len(avgs))
91+
fitting.compute_input_stats(all_data, protection=1e-2)
92+
frefs_inv = 1.0 / frefs
93+
arefs_inv = 1.0 / arefs
94+
frefs_inv[frefs_inv > 100] = 100
95+
arefs_inv[arefs_inv > 100] = 100
96+
np.testing.assert_almost_equal(frefa, to_numpy_array(fitting.fparam_avg))
97+
np.testing.assert_almost_equal(
98+
frefs_inv, to_numpy_array(fitting.fparam_inv_std)
99+
)
100+
np.testing.assert_almost_equal(arefa, to_numpy_array(fitting.aparam_avg))
101+
np.testing.assert_almost_equal(
102+
arefs_inv, to_numpy_array(fitting.aparam_inv_std)
103+
)

0 commit comments

Comments (0)