
Commit 4b58c91

Merge branch 'devel' into add_random_fit_finetune
2 parents: 4a8a109 + 4e72a97

50 files changed

Lines changed: 4745 additions & 528 deletions


.github/workflows/test_cuda.yml

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ jobs:
       - run: python -m pytest source/tests --durations=0
         env:
           NUM_WORKERS: 0
+          CUDA_VISIBLE_DEVICES: 0
       - name: Download libtorch
         run: |
           wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.1%2Bcu121.zip -O libtorch.zip
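The new env entry pins the test job to a single GPU. A minimal illustration of the effect (not part of the commit; assumes a multi-GPU runner):

# CUDA_VISIBLE_DEVICES must be set before CUDA is initialized in the process.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

if torch.cuda.is_available():
    print(torch.cuda.device_count())  # prints 1 even on a multi-GPU machine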

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 10 additions & 3 deletions
@@ -196,11 +196,11 @@ def forward_atomic(
             ]
         ener_list = []
         for i, model in enumerate(self.models):
-            mapping = self.mapping_list[i]
+            type_map_model = self.mapping_list[i]
             ener_list.append(
                 model.forward_atomic(
                     extended_coord,
-                    mapping[extended_atype],
+                    type_map_model[extended_atype],
                     nlists_[i],
                     mapping,
                     fparam,
@@ -414,7 +414,12 @@ def _compute_weight(
         )

         numerator = np.sum(
-            pairwise_rr * np.exp(-pairwise_rr / self.smin_alpha), axis=-1
+            np.where(
+                nlist_larger != -1,
+                pairwise_rr * np.exp(-pairwise_rr / self.smin_alpha),
+                np.zeros_like(nlist_larger),
+            ),
+            axis=-1,
         )  # masked nnei will be zero, no need to handle
         denominator = np.sum(
             np.where(
@@ -436,5 +441,7 @@ def _compute_weight(
         smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
         coef[mid_mask] = smooth[mid_mask]
         coef[right_mask] = 0
+        # to handle masked atoms
+        coef = np.where(sigma != 0, coef, np.zeros_like(coef))
         self.zbl_weight = coef
         return [1 - np.expand_dims(coef, -1), np.expand_dims(coef, -1)]
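The numerator now zeroes the contributions of padded neighbor slots (index -1) instead of summing whatever stale distances they hold, and coef is forced to zero for masked atoms (sigma == 0). A toy reproduction of the masking pattern, with made-up values:

import numpy as np

# Padded neighbor slots hold -1; their (garbage) distances must not
# contribute to the sum. Values below are made up for illustration.
nlist_larger = np.array([[0, 2, -1, -1]])        # two real neighbors, two pads
pairwise_rr = np.array([[1.0, 2.0, 7.5, 7.5]])   # pad distances are garbage
smin_alpha = 0.1
numerator = np.sum(
    np.where(
        nlist_larger != -1,
        pairwise_rr * np.exp(-pairwise_rr / smin_alpha),
        np.zeros_like(pairwise_rr),
    ),
    axis=-1,
)  # pads contribute exactly zero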

deepmd/dpmodel/common.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
     "double": np.float64,
     "int32": np.int32,
     "int64": np.int64,
+    "bool": bool,
     "default": GLOBAL_NP_FLOAT_PRECISION,
     # NumPy doesn't have bfloat16 (and doesn't plan to add)
     # ml_dtypes is a solution, but it seems not supporting np.save/np.load
@@ -39,6 +40,7 @@
     np.int32: "int32",
     np.int64: "int64",
     ml_dtypes.bfloat16: "bfloat16",
+    bool: "bool",
 }
 assert set(RESERVED_PRECISON_DICT.keys()) == set(PRECISION_DICT.values())
 DEFAULT_PRECISION = "float64"
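With bool registered in both tables, a precision name now round-trips through the dtype maps. A minimal sketch assuming the two dicts above (keeping the code base's historical spelling of RESERVED_PRECISON_DICT):

dtype = PRECISION_DICT["bool"]            # -> bool
assert RESERVED_PRECISON_DICT[dtype] == "bool"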

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@


 def np_softmax(x, axis=-1):
+    x = np.nan_to_num(x)  # to avoid value warning
     e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
     return e_x / np.sum(e_x, axis=axis, keepdims=True)

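Without the added nan_to_num call, a NaN logit (e.g. from a fully masked attention row) propagates through exp and poisons the whole softmax row while emitting runtime warnings. An illustrative check, not from the commit:

import numpy as np

x = np.array([[1.0, 2.0, np.nan]])
x = np.nan_to_num(x)  # NaN -> 0.0, +/-inf -> large finite values
e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
print(e_x / np.sum(e_x, axis=-1, keepdims=True))  # finite probabilities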

deepmd/dpmodel/model/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -15,6 +15,9 @@
 from .dp_model import (
     DPModelCommon,
 )
+from .ener_model import (
+    EnergyModel,
+)
 from .make_model import (
     make_model,
 )
@@ -23,6 +26,7 @@
 )

 __all__ = [
+    "EnergyModel",
     "DPModelCommon",
     "SpinModel",
     "make_model",

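With the re-export in place, the native-backend energy model can be imported directly from the package (per the __all__ change above):

from deepmd.dpmodel.model import EnergyModel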
deepmd/dpmodel/model/spin_model.py

Lines changed: 66 additions & 9 deletions
@@ -10,15 +10,21 @@
 from deepmd.dpmodel.atomic_model.dp_atomic_model import (
     DPAtomicModel,
 )
+from deepmd.dpmodel.common import (
+    NativeOP,
+)
 from deepmd.dpmodel.model.make_model import (
     make_model,
 )
+from deepmd.dpmodel.output_def import (
+    ModelOutputDef,
+)
 from deepmd.utils.spin import (
     Spin,
 )


-class SpinModel:
+class SpinModel(NativeOP):
     """A spin model wrapper, with spin input preprocess and output split."""

     def __init__(
@@ -152,15 +158,20 @@ def extend_nlist(extended_atype, nlist):
         nlist_shift = nlist + nall
         nlist[~nlist_mask] = -1
         nlist_shift[~nlist_mask] = -1
-        self_spin = np.arange(0, nloc, dtype=nlist.dtype) + nall
-        self_spin = self_spin.reshape(1, -1, 1).repeat(nframes, axis=0)
-        # self spin + real neighbor + virtual neighbor
+        self_real = (
+            np.arange(0, nloc, dtype=nlist.dtype)
+            .reshape(1, -1, 1)
+            .repeat(nframes, axis=0)
+        )
+        self_spin = self_real + nall
+        # real atom's neighbors: self spin + real neighbor + virtual neighbor
+        # nf x nloc x (1 + nnei + nnei)
+        real_nlist = np.concatenate([self_spin, nlist, nlist_shift], axis=-1)
+        # spin atom's neighbors: real + real neighbor + virtual neighbor
         # nf x nloc x (1 + nnei + nnei)
-        extended_nlist = np.concatenate([self_spin, nlist, nlist_shift], axis=-1)
+        spin_nlist = np.concatenate([self_real, nlist, nlist_shift], axis=-1)
         # nf x (nloc + nloc) x (1 + nnei + nnei)
-        extended_nlist = np.concatenate(
-            [extended_nlist, -1 * np.ones_like(extended_nlist)], axis=-2
-        )
+        extended_nlist = np.concatenate([real_nlist, spin_nlist], axis=-2)
         # update the index for switch
         first_part_index = (nloc <= extended_nlist) & (extended_nlist < nall)
         second_part_index = (nall <= extended_nlist) & (extended_nlist < (nall + nloc))
@@ -187,12 +198,40 @@ def concat_switch_virtual(extended_tensor, extended_tensor_virtual, nloc: int):
         extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:]
         return extended_tensor_updated.reshape(out_shape)

+    @staticmethod
+    def expand_aparam(aparam, nloc: int):
+        """Expand the atom parameters for virtual atoms if necessary."""
+        nframes, natom, numb_aparam = aparam.shape
+        if natom == nloc:  # good
+            pass
+        elif natom < nloc:  # for spin with virtual atoms
+            aparam = np.concatenate(
+                [
+                    aparam,
+                    np.zeros(
+                        [nframes, nloc - natom, numb_aparam],
+                        dtype=aparam.dtype,
+                    ),
+                ],
+                axis=1,
+            )
+        else:
+            raise ValueError(
+                f"get an input aparam with {aparam.shape[1]} inputs, "
+                f"which is larger than {nloc} atoms."
+            )
+        return aparam
+
     def get_type_map(self) -> List[str]:
         """Get the type map."""
         tmap = self.backbone_model.get_type_map()
         ntypes = len(tmap) // 2  # ignore the virtual type
         return tmap[:ntypes]

+    def get_ntypes(self):
+        """Returns the number of element types."""
+        return len(self.get_type_map())
+
     def get_rcut(self):
         """Get the cut-off radius."""
         return self.backbone_model.get_rcut()
@@ -251,6 +290,16 @@ def has_spin() -> bool:
         """Returns whether it has spin input and output."""
         return True

+    def model_output_def(self):
+        """Get the output def for the model."""
+        model_output_type = self.backbone_model.model_output_type()
+        if "mask" in model_output_type:
+            model_output_type.pop(model_output_type.index("mask"))
+        var_name = model_output_type[0]
+        backbone_model_atomic_output_def = self.backbone_model.atomic_output_def()
+        backbone_model_atomic_output_def[var_name].magnetic = True
+        return ModelOutputDef(backbone_model_atomic_output_def)
+
     def __getattr__(self, name):
         """Get attribute from the wrapped model."""
         if name in self.__dict__:
@@ -313,8 +362,12 @@ def call(
         The keys are defined by the `ModelOutputDef`.

         """
-        nframes, nloc = coord.shape[:2]
+        nframes, nloc = atype.shape[:2]
+        coord = coord.reshape(nframes, nloc, 3)
+        spin = spin.reshape(nframes, nloc, 3)
         coord_updated, atype_updated = self.process_spin_input(coord, atype, spin)
+        if aparam is not None:
+            aparam = self.expand_aparam(aparam, nloc * 2)
         model_predict = self.backbone_model.call(
             coord_updated,
             atype_updated,
@@ -383,6 +436,8 @@ def call_lower(
         ) = self.process_spin_input_lower(
             extended_coord, extended_atype, extended_spin, nlist, mapping=mapping
         )
+        if aparam is not None:
+            aparam = self.expand_aparam(aparam, nloc * 2)
         model_predict = self.backbone_model.call_lower(
             extended_coord_updated,
             extended_atype_updated,
@@ -401,3 +456,5 @@ def call_lower(
         )[0]
         # for now omit the grad output
         return model_predict
+
+    forward_lower = call_lower
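The new expand_aparam hook exists because the spin wrapper doubles the local atoms with virtual spin atoms, so per-atom parameters supplied for the nloc real atoms must be zero-padded to 2 * nloc before reaching the backbone. A standalone sketch of the padding, with made-up shapes:

import numpy as np

nframes, nloc, numb_aparam = 2, 3, 1
aparam = np.ones([nframes, nloc, numb_aparam])   # given for real atoms only
padded = np.concatenate(
    [aparam, np.zeros([nframes, nloc, numb_aparam], dtype=aparam.dtype)],
    axis=1,
)  # zeros fill the slots of the virtual (spin) atoms
assert padded.shape == (nframes, 2 * nloc, numb_aparam)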

deepmd/dpmodel/output_def.py

Lines changed: 15 additions & 1 deletion
@@ -228,6 +228,11 @@ def __init__(
     def size(self):
         return self.output_size

+    def squeeze(self, dim):
+        # squeeze the shape on the given dimension
+        if -len(self.shape) <= dim < len(self.shape) and self.shape[dim] == 1:
+            self.shape.pop(dim)
+

 class FittingOutputDef:
     """Defines the shapes and other properties of the fitting network outputs.
@@ -306,7 +311,6 @@ def __getitem__(

     def get_data(
         self,
-        key: str,
     ) -> Dict[str, OutputVariableDef]:
         return self.var_defs

@@ -402,6 +406,16 @@ def check_operation_applied(
     return var_def.category & op.value == op.value


+def check_deriv(var_def: OutputVariableDef) -> bool:
+    """Check if a variable is obtained by derivative."""
+    deriv = (
+        check_operation_applied(var_def, OutputVariableOperation.DERV_R)
+        or check_operation_applied(var_def, OutputVariableOperation._SEC_DERV_R)
+        or check_operation_applied(var_def, OutputVariableOperation.DERV_C)
+    )
+    return deriv
+
+
 def do_reduce(
     def_outp_data: Dict[str, OutputVariableDef],
 ) -> Dict[str, OutputVariableDef]:
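check_deriv folds the three derivative flags into a single test on top of check_operation_applied, which is a plain bit-field check. A minimal sketch of that flag test (flag values are illustrative, not the library's):

def applied(category: int, op: int) -> bool:
    # an operation applies iff all of its bits are set in the category
    return category & op == op

REDU, DERV_R, DERV_C = 1, 2, 4           # illustrative flag values
assert applied(REDU | DERV_R, DERV_R)    # a force-like, derived output
assert not applied(REDU, DERV_R)         # a plain reduced output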

deepmd/pt/model/atomic_model/linear_atomic_model.py

Lines changed: 6 additions & 3 deletions
@@ -224,12 +224,12 @@ def forward_atomic(
         ener_list = []

         for i, model in enumerate(self.models):
-            mapping = self.mapping_list[i]
+            type_map_model = self.mapping_list[i].to(extended_atype.device)
             # apply bias to each individual model
             ener_list.append(
                 model.forward_common_atomic(
                     extended_coord,
-                    mapping[extended_atype],
+                    type_map_model[extended_atype],
                     nlists_[i],
                     mapping,
                     fparam,
@@ -239,7 +239,10 @@ def forward_atomic(
         weights = self._compute_weight(extended_coord, extended_atype, nlists_)

         fit_ret = {
-            "energy": torch.sum(torch.stack(ener_list) * torch.stack(weights), dim=0),
+            "energy": torch.sum(
+                torch.stack(ener_list) * torch.stack(weights).to(extended_atype.device),
+                dim=0,
+            ),
         }  # (nframes, nloc, 1)
         return fit_ret
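Both fixes follow the same rule: buffers created on CPU at construction time (the per-model type maps, the stacked weights) must be moved to the device of the tensors they are combined with. A toy reproduction of the failure mode, not from the commit:

import torch

# Indexing a CPU tensor with a CUDA index tensor raises a device-mismatch
# error, hence the explicit .to(...) before indexing.
device = "cuda" if torch.cuda.is_available() else "cpu"
extended_atype = torch.tensor([0, 1, 1, 0], device=device)
type_map_model = torch.tensor([0, 0, 1])  # created on CPU at init time
mapped = type_map_model.to(extended_atype.device)[extended_atype]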

deepmd/pt/model/atomic_model/polar_atomic_model.py

Lines changed: 2 additions & 1 deletion
@@ -49,7 +49,8 @@ def apply_out_stat(

         # (nframes, nloc, 1)
         modified_bias = (
-            modified_bias.unsqueeze(-1) * self.fitting_net.scale[atype]
+            modified_bias.unsqueeze(-1)
+            * (self.fitting_net.scale.to(atype.device))[atype]
         )

         eye = torch.eye(3, dtype=dtype, device=device)

deepmd/pt/model/descriptor/hybrid.py

Lines changed: 6 additions & 2 deletions
@@ -43,6 +43,8 @@ class DescrptHybrid(BaseDescriptor, torch.nn.Module):
     The descriptor can be either an object or a dictionary.
     """

+    nlist_cut_idx: List[torch.Tensor]
+
     def __init__(
         self,
         list: List[Union[BaseDescriptor, Dict[str, Any]]],
@@ -278,11 +280,13 @@ def forward(
         for ii, descrpt in enumerate(self.descrpt_list):
             # cut the nlist to the correct length
             if self.mixed_types() == descrpt.mixed_types():
-                nl = nlist[:, :, self.nlist_cut_idx[ii]]
+                nl = nlist[:, :, self.nlist_cut_idx[ii].to(atype_ext.device)]
             else:
                 # mixed_types is True, but descrpt.mixed_types is False
                 assert nl_distinguish_types is not None
-                nl = nl_distinguish_types[:, :, self.nlist_cut_idx[ii]]
+                nl = nl_distinguish_types[
+                    :, :, self.nlist_cut_idx[ii].to(atype_ext.device)
+                ]
             odescriptor, gr, g2, h2, sw = descrpt(coord_ext, atype_ext, nl, mapping)
             out_descriptor.append(odescriptor)
             if gr is not None:
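The class-level annotation nlist_cut_idx: List[torch.Tensor] is there for TorchScript: a plain list of tensors held as a module attribute must be typed on the class to be scriptable. A minimal sketch of the pattern (hypothetical Holder class, not from the commit):

from typing import List

import torch


class Holder(torch.nn.Module):
    idx_list: List[torch.Tensor]  # required for torch.jit.script

    def __init__(self):
        super().__init__()
        self.idx_list = [torch.tensor([0, 1]), torch.tensor([2])]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, self.idx_list[0]]


scripted = torch.jit.script(Holder())  # scriptable thanks to the annotation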
