1010from deepmd .dpmodel .atomic_model .dp_atomic_model import (
1111 DPAtomicModel ,
1212)
13+ from deepmd .dpmodel .common import (
14+ NativeOP ,
15+ )
1316from deepmd .dpmodel .model .make_model import (
1417 make_model ,
1518)
19+ from deepmd .dpmodel .output_def import (
20+ ModelOutputDef ,
21+ )
1622from deepmd .utils .spin import (
1723 Spin ,
1824)
1925
2026
21- class SpinModel :
27+ class SpinModel ( NativeOP ) :
2228 """A spin model wrapper, with spin input preprocess and output split."""
2329
2430 def __init__ (
@@ -152,15 +158,20 @@ def extend_nlist(extended_atype, nlist):
152158 nlist_shift = nlist + nall
153159 nlist [~ nlist_mask ] = - 1
154160 nlist_shift [~ nlist_mask ] = - 1
155- self_spin = np .arange (0 , nloc , dtype = nlist .dtype ) + nall
156- self_spin = self_spin .reshape (1 , - 1 , 1 ).repeat (nframes , axis = 0 )
157- # self spin + real neighbor + virtual neighbor
161+ self_real = (
162+ np .arange (0 , nloc , dtype = nlist .dtype )
163+ .reshape (1 , - 1 , 1 )
164+ .repeat (nframes , axis = 0 )
165+ )
166+ self_spin = self_real + nall
167+ # real atom's neighbors: self spin + real neighbor + virtual neighbor
168+ # nf x nloc x (1 + nnei + nnei)
169+ real_nlist = np .concatenate ([self_spin , nlist , nlist_shift ], axis = - 1 )
170+ # spin atom's neighbors: real + real neighbor + virtual neighbor
158171 # nf x nloc x (1 + nnei + nnei)
159- extended_nlist = np .concatenate ([self_spin , nlist , nlist_shift ], axis = - 1 )
172+ spin_nlist = np .concatenate ([self_real , nlist , nlist_shift ], axis = - 1 )
160173 # nf x (nloc + nloc) x (1 + nnei + nnei)
161- extended_nlist = np .concatenate (
162- [extended_nlist , - 1 * np .ones_like (extended_nlist )], axis = - 2
163- )
174+ extended_nlist = np .concatenate ([real_nlist , spin_nlist ], axis = - 2 )
164175 # update the index for switch
165176 first_part_index = (nloc <= extended_nlist ) & (extended_nlist < nall )
166177 second_part_index = (nall <= extended_nlist ) & (extended_nlist < (nall + nloc ))
@@ -187,12 +198,40 @@ def concat_switch_virtual(extended_tensor, extended_tensor_virtual, nloc: int):
187198 extended_tensor_updated [:, nloc + nall :] = extended_tensor_virtual [:, nloc :]
188199 return extended_tensor_updated .reshape (out_shape )
189200
201+ @staticmethod
202+ def expand_aparam (aparam , nloc : int ):
203+ """Expand the atom parameters for virtual atoms if necessary."""
204+ nframes , natom , numb_aparam = aparam .shape
205+ if natom == nloc : # good
206+ pass
207+ elif natom < nloc : # for spin with virtual atoms
208+ aparam = np .concatenate (
209+ [
210+ aparam ,
211+ np .zeros (
212+ [nframes , nloc - natom , numb_aparam ],
213+ dtype = aparam .dtype ,
214+ ),
215+ ],
216+ axis = 1 ,
217+ )
218+ else :
219+ raise ValueError (
220+ f"get an input aparam with { aparam .shape [1 ]} inputs, " ,
221+ f"which is larger than { nloc } atoms." ,
222+ )
223+ return aparam
224+
190225 def get_type_map (self ) -> List [str ]:
191226 """Get the type map."""
192227 tmap = self .backbone_model .get_type_map ()
193228 ntypes = len (tmap ) // 2 # ignore the virtual type
194229 return tmap [:ntypes ]
195230
231+ def get_ntypes (self ):
232+ """Returns the number of element types."""
233+ return len (self .get_type_map ())
234+
196235 def get_rcut (self ):
197236 """Get the cut-off radius."""
198237 return self .backbone_model .get_rcut ()
@@ -251,6 +290,16 @@ def has_spin() -> bool:
251290 """Returns whether it has spin input and output."""
252291 return True
253292
293+ def model_output_def (self ):
294+ """Get the output def for the model."""
295+ model_output_type = self .backbone_model .model_output_type ()
296+ if "mask" in model_output_type :
297+ model_output_type .pop (model_output_type .index ("mask" ))
298+ var_name = model_output_type [0 ]
299+ backbone_model_atomic_output_def = self .backbone_model .atomic_output_def ()
300+ backbone_model_atomic_output_def [var_name ].magnetic = True
301+ return ModelOutputDef (backbone_model_atomic_output_def )
302+
254303 def __getattr__ (self , name ):
255304 """Get attribute from the wrapped model."""
256305 if name in self .__dict__ :
@@ -313,8 +362,12 @@ def call(
313362 The keys are defined by the `ModelOutputDef`.
314363
315364 """
316- nframes , nloc = coord .shape [:2 ]
365+ nframes , nloc = atype .shape [:2 ]
366+ coord = coord .reshape (nframes , nloc , 3 )
367+ spin = spin .reshape (nframes , nloc , 3 )
317368 coord_updated , atype_updated = self .process_spin_input (coord , atype , spin )
369+ if aparam is not None :
370+ aparam = self .expand_aparam (aparam , nloc * 2 )
318371 model_predict = self .backbone_model .call (
319372 coord_updated ,
320373 atype_updated ,
@@ -383,6 +436,8 @@ def call_lower(
383436 ) = self .process_spin_input_lower (
384437 extended_coord , extended_atype , extended_spin , nlist , mapping = mapping
385438 )
439+ if aparam is not None :
440+ aparam = self .expand_aparam (aparam , nloc * 2 )
386441 model_predict = self .backbone_model .call_lower (
387442 extended_coord_updated ,
388443 extended_atype_updated ,
@@ -401,3 +456,5 @@ def call_lower(
401456 )[0 ]
402457 # for now omit the grad output
403458 return model_predict
459+
460+ forward_lower = call_lower
0 commit comments