@@ -134,7 +134,7 @@ def test(
134134 )
135135
136136 if isinstance (dp , DeepPot ):
137- #err, find_energy, find_force, find_virial = test_ener(
137+ # err, find_energy, find_force, find_virial = test_ener(
138138 err = test_ener (
139139 dp ,
140140 data ,
@@ -559,37 +559,33 @@ def test_ener(
559559 append = append_detail ,
560560 )
561561 if not out_put_spin :
562- return (
563- {
564- "mae_e" : (mae_e , energy .size ),
565- "mae_ea" : (mae_ea , energy .size ),
566- "mae_f" : (mae_f , force .size ),
567- "mae_v" : (mae_v , virial .size ),
568- "mae_va" : (mae_va , virial .size ),
569- "rmse_e" : (rmse_e , energy .size ),
570- "rmse_ea" : (rmse_ea , energy .size ),
571- "rmse_f" : (rmse_f , force .size ),
572- "rmse_v" : (rmse_v , virial .size ),
573- "rmse_va" : (rmse_va , virial .size ),
574- }#find_energy,find_force,find_virial,
575- )
562+ return {
563+ "mae_e" : (mae_e , energy .size ),
564+ "mae_ea" : (mae_ea , energy .size ),
565+ "mae_f" : (mae_f , force .size ),
566+ "mae_v" : (mae_v , virial .size ),
567+ "mae_va" : (mae_va , virial .size ),
568+ "rmse_e" : (rmse_e , energy .size ),
569+ "rmse_ea" : (rmse_ea , energy .size ),
570+ "rmse_f" : (rmse_f , force .size ),
571+ "rmse_v" : (rmse_v , virial .size ),
572+ "rmse_va" : (rmse_va , virial .size ),
573+ } # find_energy,find_force,find_virial,
576574 else :
577- return (
578- {
579- "mae_e" : (mae_e , energy .size ),
580- "mae_ea" : (mae_ea , energy .size ),
581- "mae_fr" : (mae_fr , force_r .size ),
582- "mae_fm" : (mae_fm , force_m .size ),
583- "mae_v" : (mae_v , virial .size ),
584- "mae_va" : (mae_va , virial .size ),
585- "rmse_e" : (rmse_e , energy .size ),
586- "rmse_ea" : (rmse_ea , energy .size ),
587- "rmse_fr" : (rmse_fr , force_r .size ),
588- "rmse_fm" : (rmse_fm , force_m .size ),
589- "rmse_v" : (rmse_v , virial .size ),
590- "rmse_va" : (rmse_va , virial .size ),
591- }#find_energy,find_force,find_virial,
592- )
575+ return {
576+ "mae_e" : (mae_e , energy .size ),
577+ "mae_ea" : (mae_ea , energy .size ),
578+ "mae_fr" : (mae_fr , force_r .size ),
579+ "mae_fm" : (mae_fm , force_m .size ),
580+ "mae_v" : (mae_v , virial .size ),
581+ "mae_va" : (mae_va , virial .size ),
582+ "rmse_e" : (rmse_e , energy .size ),
583+ "rmse_ea" : (rmse_ea , energy .size ),
584+ "rmse_fr" : (rmse_fr , force_r .size ),
585+ "rmse_fm" : (rmse_fm , force_m .size ),
586+ "rmse_v" : (rmse_v , virial .size ),
587+ "rmse_va" : (rmse_va , virial .size ),
588+ } # find_energy,find_force,find_virial,
593589
594590
595591def print_ener_sys_avg (avg : dict [str , float ]) -> None :
0 commit comments