Skip to content

Commit 37bee25

Browse files
authored
fix(pt): add finetune_head to argcheck (#3967)
Add `finetune_head` to argcheck. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a new `finetune_head` argument for specifying the fitting net during multi-task fine-tuning, with optional random initialization if not set. - **Bug Fixes** - Improved handling for specific conditions by automatically removing the "finetune_head" key from the configuration. - **Tests** - Updated multitask training and finetuning tests to include new configuration manipulations. - Removed the `_comment` field from test configuration files to ensure cleaner test setups. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 698ff6f commit 37bee25

3 files changed

Lines changed: 20 additions & 1 deletion

File tree

deepmd/utils/argcheck.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,10 @@ def model_args(exclude_hybrid=False):
     doc_spin = "The settings for systems with spin."
     doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types"
     doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other."
+    doc_finetune_head = (
+        "The chosen fitting net to fine-tune on, when doing multi-task fine-tuning. "
+        "If not set or set to 'RANDOM', the fitting net will be randomly initialized."
+    )

     hybrid_models = []
     if not exclude_hybrid:
@@ -1629,6 +1633,12 @@ def model_args(exclude_hybrid=False):
                 fold_subdoc=True,
             ),
             Argument("spin", dict, spin_args(), [], optional=True, doc=doc_spin),
+            Argument(
+                "finetune_head",
+                str,
+                optional=True,
+                doc=doc_only_pt_supported + doc_finetune_head,
+            ),
         ],
         [
             Variant(

source/tests/pt/model/water/multitask.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@
         "_comment": "that's all"
     },
     "loss_dict": {
-        "_comment": " that's all",
         "model_1": {
             "type": "ener",
             "start_pref_e": 0.02,

source/tests/pt/test_multitask.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@
 from deepmd.pt.utils.multi_task import (
     preprocess_shared_params,
 )
+from deepmd.utils.argcheck import (
+    normalize,
+)
+from deepmd.utils.compat import (
+    update_deepmd_input,
+)

 from .model.test_permutation import (
     model_dpa1,
@@ -39,6 +45,8 @@ def setUpModule():
 class MultiTaskTrainTest:
     def test_multitask_train(self):
         # test multitask training
+        self.config = update_deepmd_input(self.config, warning=True)
+        self.config = normalize(self.config, multi_task=True)
         trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links)
         trainer.run()
         # check model keys
@@ -124,6 +132,8 @@ def test_multitask_train(self):
             finetune_model,
             self.origin_config["model"],
         )
+        self.origin_config = update_deepmd_input(self.origin_config, warning=True)
+        self.origin_config = normalize(self.origin_config, multi_task=True)
         trainer_finetune = get_trainer(
             deepcopy(self.origin_config),
             finetune_model=finetune_model,

0 commit comments

Comments (0)