Skip to content

Commit 65facbb

Browse files
Fix: pt tensor loss label name (#4587)
To address the NaN polar loss mentioned in #4586.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **Refactor**
  - Adjusted the internal processing order in computation routines to enhance consistency while maintaining the same overall user experience.
  - Updated model prediction handling to ensure compatibility in shape during statistical computations, reducing potential runtime errors.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5bae5f2 commit 65facbb

2 files changed

Lines changed: 7 additions & 7 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,13 +1253,11 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
12531253
if "mask" in model_output_type:
12541254
model_output_type.pop(model_output_type.index("mask"))
12551255
tensor_name = model_output_type[0]
1256-
loss_params["tensor_name"] = tensor_name
12571256
loss_params["tensor_size"] = _model.model_output_def()[tensor_name].output_size
1258-
label_name = tensor_name
1259-
if label_name == "polarizability":
1260-
label_name = "polar"
1261-
loss_params["label_name"] = label_name
1262-
loss_params["tensor_name"] = label_name
1257+
loss_params["label_name"] = tensor_name
1258+
if tensor_name == "polarizability":
1259+
tensor_name = "polar"
1260+
loss_params["tensor_name"] = tensor_name
12631261
return TensorLoss(**loss_params)
12641262
elif loss_type == "property":
12651263
task_dim = _model.get_task_dim()

deepmd/pt/utils/stat.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,9 @@ def compute_output_stats_global(
477477
# subtract the model bias and output the delta bias
478478

479479
stats_input = {
480-
kk: merged_output[kk] - model_pred[kk] for kk in keys if kk in merged_output
480+
kk: merged_output[kk] - model_pred[kk].reshape(merged_output[kk].shape)
481+
for kk in keys
482+
if kk in merged_output
481483
}
482484

483485
bias_atom_e = {}

0 commit comments

Comments
 (0)