Skip to content

Commit 2a32c87

Browse files
authored
Move model deviation and ase calculator to deepmd_utils (#3173)
..., so they can benifit from multiple-backend DeepPot. Update docs. --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 4e11233 commit 2a32c87

7 files changed

Lines changed: 680 additions & 656 deletions

File tree

deepmd/calculator.py

Lines changed: 5 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -1,144 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
"""ASE calculator interface module."""
3-
4-
from pathlib import (
5-
Path,
6-
)
7-
from typing import (
8-
TYPE_CHECKING,
9-
ClassVar,
10-
Dict,
11-
List,
12-
Optional,
13-
Union,
14-
)
15-
16-
from ase.calculators.calculator import (
17-
Calculator,
18-
PropertyNotImplementedError,
19-
all_changes,
20-
)
21-
22-
from deepmd import (
23-
DeepPotential,
2+
from deepmd_utils.calculator import (
3+
DP,
244
)
255

26-
if TYPE_CHECKING:
27-
from ase import (
28-
Atoms,
29-
)
30-
31-
__all__ = ["DP"]
32-
33-
34-
class DP(Calculator):
35-
"""Implementation of ASE deepmd calculator.
36-
37-
Implemented propertie are `energy`, `forces` and `stress`
38-
39-
Parameters
40-
----------
41-
model : Union[str, Path]
42-
path to the model
43-
label : str, optional
44-
calculator label, by default "DP"
45-
type_dict : Dict[str, int], optional
46-
mapping of element types and their numbers, best left None and the calculator
47-
will infer this information from model, by default None
48-
neighbor_list : ase.neighborlist.NeighborList, optional
49-
The neighbor list object. If None, then build the native neighbor list.
50-
51-
Examples
52-
--------
53-
Compute potential energy
54-
55-
>>> from ase import Atoms
56-
>>> from deepmd.calculator import DP
57-
>>> water = Atoms('H2O',
58-
>>> positions=[(0.7601, 1.9270, 1),
59-
>>> (1.9575, 1, 1),
60-
>>> (1., 1., 1.)],
61-
>>> cell=[100, 100, 100],
62-
>>> calculator=DP(model="frozen_model.pb"))
63-
>>> print(water.get_potential_energy())
64-
>>> print(water.get_forces())
65-
66-
Run BFGS structure optimization
67-
68-
>>> from ase.optimize import BFGS
69-
>>> dyn = BFGS(water)
70-
>>> dyn.run(fmax=1e-6)
71-
>>> print(water.get_positions())
72-
"""
73-
74-
name = "DP"
75-
implemented_properties: ClassVar[List[str]] = [
76-
"energy",
77-
"free_energy",
78-
"forces",
79-
"virial",
80-
"stress",
81-
]
82-
83-
def __init__(
84-
self,
85-
model: Union[str, "Path"],
86-
label: str = "DP",
87-
type_dict: Optional[Dict[str, int]] = None,
88-
neighbor_list=None,
89-
**kwargs,
90-
) -> None:
91-
Calculator.__init__(self, label=label, **kwargs)
92-
self.dp = DeepPotential(str(Path(model).resolve()), neighbor_list=neighbor_list)
93-
if type_dict:
94-
self.type_dict = type_dict
95-
else:
96-
self.type_dict = dict(
97-
zip(self.dp.get_type_map(), range(self.dp.get_ntypes()))
98-
)
99-
100-
def calculate(
101-
self,
102-
atoms: Optional["Atoms"] = None,
103-
properties: List[str] = ["energy", "forces", "virial"],
104-
system_changes: List[str] = all_changes,
105-
):
106-
"""Run calculation with deepmd model.
107-
108-
Parameters
109-
----------
110-
atoms : Optional[Atoms], optional
111-
atoms object to run the calculation on, by default None
112-
properties : List[str], optional
113-
unused, only for function signature compatibility,
114-
by default ["energy", "forces", "stress"]
115-
system_changes : List[str], optional
116-
unused, only for function signature compatibility, by default all_changes
117-
"""
118-
if atoms is not None:
119-
self.atoms = atoms.copy()
120-
121-
coord = self.atoms.get_positions().reshape([1, -1])
122-
if sum(self.atoms.get_pbc()) > 0:
123-
cell = self.atoms.get_cell().reshape([1, -1])
124-
else:
125-
cell = None
126-
symbols = self.atoms.get_chemical_symbols()
127-
atype = [self.type_dict[k] for k in symbols]
128-
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)
129-
self.results["energy"] = e[0][0]
130-
# see https://gitlab.com/ase/ase/-/merge_requests/2485
131-
self.results["free_energy"] = e[0][0]
132-
self.results["forces"] = f[0]
133-
self.results["virial"] = v[0].reshape(3, 3)
134-
135-
# convert virial into stress for lattice relaxation
136-
if "stress" in properties:
137-
if sum(atoms.get_pbc()) > 0:
138-
# the usual convention (tensile stress is positive)
139-
# stress = -virial / volume
140-
stress = -0.5 * (v[0].copy() + v[0].copy().T) / atoms.get_volume()
141-
# Voigt notation
142-
self.results["stress"] = stress.flat[[0, 4, 8, 5, 2, 1]]
143-
else:
144-
raise PropertyNotImplementedError
6+
__all__ = [
7+
"DP",
8+
]

0 commit comments

Comments
 (0)