|
26 | 26 | save_checkpoint, |
27 | 27 | ModelState, |
28 | 28 | OptimizerState, |
| 29 | + convert_dcp_to_hf, |
29 | 30 | ) |
| 31 | +from tests.unit.test_utils import simple_loss |
30 | 32 |
|
31 | 33 | # Define basic test config |
32 | 34 | simple_policy_config = { |
33 | 35 | "model_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM", |
34 | 36 | "tokenizer_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM", |
35 | | - "train_global_batch_size": 32, |
| 37 | + "train_global_batch_size": 4, |
36 | 38 | "train_micro_batch_size": 1, |
37 | 39 | "logprob_batch_size": 1, |
38 | 40 | "max_total_sequence_length": 1024, |
39 | 41 | "precision": "float32", |
| 42 | + "optimizer": { |
| 43 | + "name": "torch.optim.AdamW", |
| 44 | + "kwargs": { |
| 45 | + "lr": 5e-6, |
| 46 | + "weight_decay": 0.01, |
| 47 | + "betas": [0.9, 0.999], |
| 48 | + "eps": 1e-8, |
| 49 | + }, |
| 50 | + }, |
40 | 51 | } |
41 | 52 |
|
42 | 53 |
|
@@ -85,12 +96,14 @@ def tokenizer(): |
85 | 96 | @pytest.fixture(scope="function") |
86 | 97 | def policy(cluster, tokenizer): |
87 | 98 | """Initialize the policy.""" |
88 | | - return HfPolicy( |
| 99 | + policy = HfPolicy( |
89 | 100 | cluster=cluster, |
90 | 101 | config=simple_policy_config, |
91 | | - init_optimizer=False, |
| 102 | + init_optimizer=True, |
92 | 103 | init_reference_model=False, |
93 | 104 | ) |
| 105 | + yield policy |
| 106 | + policy.worker_group.shutdown() |
94 | 107 |
|
95 | 108 |
|
96 | 109 | def get_dummy_state_dict(state_dict, dummy_dict={}): |
@@ -118,6 +131,15 @@ def check_dict_equality(dict1, dict2): |
118 | 131 | assert dict1[k] == dict2[k] |
119 | 132 |
|
120 | 133 |
|
| 134 | +def assert_recursive_dict_different(dict1, dict2): |
| 135 | + """Recursively assert that two dictionaries are different""" |
| 136 | + try: |
| 137 | + check_dict_equality(dict1, dict2) |
| 138 | + except AssertionError: |
| 139 | + return |
| 140 | + raise AssertionError("Dictionaries are equal") |
| 141 | + |
| 142 | + |
121 | 143 | def test_model_state(mock_experiment): |
122 | 144 | test_model, _, _ = mock_experiment |
123 | 145 | model_state = ModelState(test_model) |
@@ -275,14 +297,84 @@ def test_save_and_load_hf_checkpoint(policy): |
275 | 297 | "model.safetensors.index.json", |
276 | 298 | } |
277 | 299 |
|
278 | | - coverted_model = AutoModelForCausalLM.from_pretrained( |
| 300 | + converted_model = AutoModelForCausalLM.from_pretrained( |
279 | 301 | os.path.join(tmp_dir, "test_hf_and_dcp-hf") |
280 | 302 | ) |
281 | 303 | original_model = AutoModelForCausalLM.from_pretrained( |
282 | 304 | simple_policy_config["model_name"] |
283 | 305 | ) |
284 | 306 |
|
285 | 307 | ## make sure converted model matches the original |
286 | | - check_dict_equality(coverted_model.state_dict(), original_model.state_dict()) |
| 308 | + check_dict_equality(converted_model.state_dict(), original_model.state_dict()) |
287 | 309 |
|
288 | | - policy.worker_group.shutdown() |
| 310 | + |
| 311 | +def test_convert_dcp_to_hf(policy): |
| 312 | + ## warm up with a forward pass |
| 313 | + ## this is needed before saving a checkpoint because FSDP does some lazy initialization |
| 314 | + input_ids = torch.randint(0, 16000, (4, 128)) # 4 sequences, each of length 128 |
| 315 | + attention_mask = torch.ones(4, 128) |
| 316 | + input_lengths = attention_mask.sum(dim=1).to(torch.int32) |
| 317 | + dummy_fwd_dict = BatchedDataDict( |
| 318 | + { |
| 319 | + "input_ids": input_ids, |
| 320 | + "input_lengths": input_lengths, |
| 321 | + "attention_mask": attention_mask, |
| 322 | + "labels": torch.randint(0, 16000, (4, 128)), |
| 323 | + } |
| 324 | + ) |
| 325 | + policy.train(dummy_fwd_dict, simple_loss) |
| 326 | + |
| 327 | + with TemporaryDirectory() as tmp_dir: |
| 328 | + policy.save_checkpoint( |
| 329 | + os.path.join(tmp_dir, "test_hf_and_dcp"), |
| 330 | + save_hf=True, |
| 331 | + save_torch_dist=True, |
| 332 | + ) |
| 333 | + |
| 334 | + ## make sure we save both HF and DCP checkpoints |
| 335 | + assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp"))) == { |
| 336 | + "__0_0.distcp", |
| 337 | + "__1_0.distcp", |
| 338 | + ".metadata", |
| 339 | + } |
| 340 | + ## 1B model has two shards |
| 341 | + assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp-hf"))) == { |
| 342 | + "config.json", |
| 343 | + "generation_config.json", |
| 344 | + "model-00001-of-00002.safetensors", |
| 345 | + "model-00002-of-00002.safetensors", |
| 346 | + "model.safetensors.index.json", |
| 347 | + } |
| 348 | + |
| 349 | + offline_converted_model_path = convert_dcp_to_hf( |
| 350 | + os.path.join(tmp_dir, "test_hf_and_dcp"), |
| 351 | + os.path.join(tmp_dir, "test_hf_and_dcp-hf-offline"), |
| 352 | + simple_policy_config["model_name"], |
| 353 | + # TODO: After the following PR gets merged: |
| 354 | + # https://github.com/NVIDIA/reinforcer/pull/148/files
| 355 | + # tokenizer should be copied from policy/tokenizer/* instead of relying on the model name |
| 356 | + # We can expose an arg at the top level --tokenizer_path to plumb that through.
| 357 | + # This is more stable than relying on the current NeMo-RL get_tokenizer() which can |
| 358 | + # change from release to release.
| 359 | + simple_policy_config["model_name"], |
| 360 | + ) |
| 361 | + |
| 362 | + offline_converted_model = AutoModelForCausalLM.from_pretrained( |
| 363 | + offline_converted_model_path |
| 364 | + ) |
| 365 | + |
| 366 | + online_converted_model = AutoModelForCausalLM.from_pretrained( |
| 367 | + os.path.join(tmp_dir, "test_hf_and_dcp-hf") |
| 368 | + ) |
| 369 | + original_model = AutoModelForCausalLM.from_pretrained( |
| 370 | + simple_policy_config["model_name"] |
| 371 | + ) |
| 372 | + |
| 373 | + ## make sure both conversions result in the same state dict
| 374 | + check_dict_equality( |
| 375 | + online_converted_model.state_dict(), offline_converted_model.state_dict() |
| 376 | + ) |
| 377 | + # Ensure the offline one is different from the original |
| 378 | + assert_recursive_dict_different( |
| 379 | + offline_converted_model.state_dict(), original_model.state_dict() |
| 380 | + ) |