Skip to content

Commit 9ac4e62

Browse files
authored
fix: convert DCP to HF script works without ray cluster (#185)
Signed-off-by: Terry Kong <terryk@nvidia.com>
1 parent 8213014 commit 9ac4e62

3 files changed

Lines changed: 175 additions & 45 deletions

File tree

examples/convert_dcp_to_hf.py

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@
1313
# limitations under the License.
1414

1515
import argparse
16-
import os
1716
import json
18-
19-
from nemo_reinforcer.distributed.virtual_cluster import init_ray, RayVirtualCluster
20-
from nemo_reinforcer.models.policy.hf_policy import HfPolicy
21-
from nemo_reinforcer.utils.config import load_config
17+
import os
18+
import torch
19+
from nemo_reinforcer.utils.native_checkpoint import convert_dcp_to_hf
2220

2321

2422
def parse_args():
@@ -51,41 +49,22 @@ def main():
5149
with open(args.config, "r") as f:
5250
config = json.load(f)
5351

54-
dcp_ckpt = args.dcp_ckpt_path
55-
hf_ckpt = args.hf_ckpt_path
56-
57-
# Extract individual configs for easier access
58-
policy_config = config["policy"]
59-
cluster_config = config["cluster"]
60-
61-
init_ray()
62-
63-
cluster = RayVirtualCluster(
64-
name="convert_cluster",
65-
bundle_ct_per_node_list=[cluster_config["gpus_per_node"]]
66-
* cluster_config["num_nodes"],
67-
use_gpus=True,
68-
num_gpus_per_node=cluster_config["gpus_per_node"],
69-
max_colocated_worker_groups=1,
52+
model_name_or_path = config["policy"]["model_name"]
53+
# TODO: After the following PR gets merged:
54+
# https://github.com/NVIDIA/reinforcer/pull/148/files
55+
# tokenizer should be copied from policy/tokenizer/* instead of relying on the model name
56+
# We can expose an arg at the top level --tokenizer_path to plumb that through.
57+
# This is more stable than relying on the current NeMo-RL get_tokenizer() which can
58+
# change from release to release.
59+
tokenizer_name_or_path = config["policy"]["model_name"]
60+
61+
hf_ckpt = convert_dcp_to_hf(
62+
dcp_ckpt_path=args.dcp_ckpt_path,
63+
hf_ckpt_path=args.hf_ckpt_path,
64+
model_name_or_path=model_name_or_path,
65+
tokenizer_name_or_path=tokenizer_name_or_path,
7066
)
71-
72-
policy = HfPolicy(
73-
cluster=cluster,
74-
config=policy_config,
75-
weights_path=dcp_ckpt,
76-
init_optimizer=False,
77-
)
78-
79-
policy.save_checkpoint(
80-
weights_path=os.path.abspath(hf_ckpt),
81-
save_hf=True,
82-
save_torch_dist=False,
83-
)
84-
85-
print(f"Saved HF checkpoint to: {hf_ckpt}-hf")
86-
87-
cluster.shutdown()
88-
policy.worker_group.shutdown()
67+
print(f"Saved HF checkpoint to: {hf_ckpt}")
8968

9069

9170
if __name__ == "__main__":

nemo_reinforcer/utils/native_checkpoint.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Any, Optional
2020
import torch
2121

22+
from transformers import AutoConfig, AutoTokenizer
2223
import torch.distributed.checkpoint as dcp
2324
from torch.distributed.checkpoint.stateful import Stateful
2425
from torch.distributed.checkpoint.state_dict import (
@@ -27,6 +28,7 @@
2728
get_optimizer_state_dict,
2829
set_optimizer_state_dict,
2930
)
31+
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
3032

3133

3234
## modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
@@ -202,3 +204,60 @@ def load_checkpoint(
202204
print(f"Loading optimizer from {optimizer_path}")
203205
optimizer_state_dict = {"optim": OptimizerState(model, optimizer, scheduler)}
204206
dcp.load(state_dict=optimizer_state_dict, checkpoint_id=optimizer_path)
207+
208+
209+
def convert_dcp_to_hf(
210+
dcp_ckpt_path: str,
211+
hf_ckpt_path: str,
212+
model_name_or_path: str,
213+
tokenizer_name_or_path: str,
214+
overwrite: bool = False,
215+
):
216+
"""Convert a Torch DCP checkpoint to a Hugging Face checkpoint.
217+
218+
This is not an optimized utility. If the checkpoint is too large, consider saving DCP during training
219+
and using this utility to convert to HF format.
220+
221+
Args:
222+
dcp_ckpt_path (str): Path to DCP checkpoint
223+
hf_ckpt_path (str): Path to save HF checkpoint
224+
model_name_or_path (str): Model name or path for config
225+
tokenizer_name_or_path (str): Tokenizer name or path.
226+
Typically the same value as model_name_or_path.
227+
overwrite (bool, optional): Whether to overwrite existing checkpoint. Defaults to False.
228+
229+
Returns:
230+
str: Path to the saved HF checkpoint
231+
232+
Raises:
233+
FileExistsError: If HF checkpoint already exists and overwrite is False
234+
"""
235+
if os.path.exists(hf_ckpt_path) and not overwrite:
236+
raise FileExistsError(
237+
f"HF checkpoint already exists at {hf_ckpt_path}. Delete it to run or set overwrite=True."
238+
)
239+
240+
os.makedirs(hf_ckpt_path, exist_ok=True)
241+
weights_path = os.path.join(hf_ckpt_path, "pytorch_model.bin")
242+
dcp_to_torch_save(dcp_ckpt_path, weights_path)
243+
244+
# Need to reload and save b/c the state dict is scoped inside the model key {"model": actual_state_dict}
245+
state_dict = torch.load(weights_path)
246+
assert set(state_dict.keys()) == {"model"}, (
247+
f"We expect that the state dict only has the top level model key, but found: {state_dict.keys()}"
248+
)
249+
torch.save(state_dict["model"], weights_path)
250+
251+
config = AutoConfig.from_pretrained(model_name_or_path)
252+
config.save_pretrained(hf_ckpt_path)
253+
254+
# TODO: After the following PR gets merged:
255+
# https://github.com/NVIDIA/reinforcer/pull/148/files
256+
# tokenizer should be copied from policy/tokenizer/* instead of relying on the model name
257+
# We can expose an arg at the top level --tokenizer_path to plumb that through.
258+
# This is more stable than relying on the current NeMo-RL get_tokenizer() which can
259+
# change from release to release.
260+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
261+
tokenizer.save_pretrained(hf_ckpt_path)
262+
263+
return hf_ckpt_path

tests/unit/utils/test_native_checkpoint.py

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,28 @@
2626
save_checkpoint,
2727
ModelState,
2828
OptimizerState,
29+
convert_dcp_to_hf,
2930
)
31+
from tests.unit.test_utils import simple_loss
3032

3133
# Define basic test config
3234
simple_policy_config = {
3335
"model_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
3436
"tokenizer_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
35-
"train_global_batch_size": 32,
37+
"train_global_batch_size": 4,
3638
"train_micro_batch_size": 1,
3739
"logprob_batch_size": 1,
3840
"max_total_sequence_length": 1024,
3941
"precision": "float32",
42+
"optimizer": {
43+
"name": "torch.optim.AdamW",
44+
"kwargs": {
45+
"lr": 5e-6,
46+
"weight_decay": 0.01,
47+
"betas": [0.9, 0.999],
48+
"eps": 1e-8,
49+
},
50+
},
4051
}
4152

4253

@@ -85,12 +96,14 @@ def tokenizer():
8596
@pytest.fixture(scope="function")
8697
def policy(cluster, tokenizer):
8798
"""Initialize the policy."""
88-
return HfPolicy(
99+
policy = HfPolicy(
89100
cluster=cluster,
90101
config=simple_policy_config,
91-
init_optimizer=False,
102+
init_optimizer=True,
92103
init_reference_model=False,
93104
)
105+
yield policy
106+
policy.worker_group.shutdown()
94107

95108

96109
def get_dummy_state_dict(state_dict, dummy_dict={}):
@@ -118,6 +131,15 @@ def check_dict_equality(dict1, dict2):
118131
assert dict1[k] == dict2[k]
119132

120133

134+
def assert_recursive_dict_different(dict1, dict2):
135+
"""Recursively assert that two dictionaries are different"""
136+
try:
137+
check_dict_equality(dict1, dict2)
138+
except AssertionError:
139+
return
140+
raise AssertionError("Dictionaries are equal")
141+
142+
121143
def test_model_state(mock_experiment):
122144
test_model, _, _ = mock_experiment
123145
model_state = ModelState(test_model)
@@ -275,14 +297,84 @@ def test_save_and_load_hf_checkpoint(policy):
275297
"model.safetensors.index.json",
276298
}
277299

278-
coverted_model = AutoModelForCausalLM.from_pretrained(
300+
converted_model = AutoModelForCausalLM.from_pretrained(
279301
os.path.join(tmp_dir, "test_hf_and_dcp-hf")
280302
)
281303
original_model = AutoModelForCausalLM.from_pretrained(
282304
simple_policy_config["model_name"]
283305
)
284306

285307
## make sure converted model matches the original
286-
check_dict_equality(coverted_model.state_dict(), original_model.state_dict())
308+
check_dict_equality(converted_model.state_dict(), original_model.state_dict())
287309

288-
policy.worker_group.shutdown()
310+
311+
def test_convert_dcp_to_hf(policy):
312+
## warm up with a forward pass
313+
## this is needed before saving a checkpoint because FSDP does some lazy initialization
314+
input_ids = torch.randint(0, 16000, (4, 128)) # 4 sequences, each of length 128
315+
attention_mask = torch.ones(4, 128)
316+
input_lengths = attention_mask.sum(dim=1).to(torch.int32)
317+
dummy_fwd_dict = BatchedDataDict(
318+
{
319+
"input_ids": input_ids,
320+
"input_lengths": input_lengths,
321+
"attention_mask": attention_mask,
322+
"labels": torch.randint(0, 16000, (4, 128)),
323+
}
324+
)
325+
policy.train(dummy_fwd_dict, simple_loss)
326+
327+
with TemporaryDirectory() as tmp_dir:
328+
policy.save_checkpoint(
329+
os.path.join(tmp_dir, "test_hf_and_dcp"),
330+
save_hf=True,
331+
save_torch_dist=True,
332+
)
333+
334+
## make sure we save both HF and DCP checkpoints
335+
assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp"))) == {
336+
"__0_0.distcp",
337+
"__1_0.distcp",
338+
".metadata",
339+
}
340+
## 1B model has two shards
341+
assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp-hf"))) == {
342+
"config.json",
343+
"generation_config.json",
344+
"model-00001-of-00002.safetensors",
345+
"model-00002-of-00002.safetensors",
346+
"model.safetensors.index.json",
347+
}
348+
349+
offline_converted_model_path = convert_dcp_to_hf(
350+
os.path.join(tmp_dir, "test_hf_and_dcp"),
351+
os.path.join(tmp_dir, "test_hf_and_dcp-hf-offline"),
352+
simple_policy_config["model_name"],
353+
# TODO: After the following PR gets merged:
354+
# https://github.com/NVIDIA/reinforcer/pull/148/files
355+
# tokenizer should be copied from policy/tokenizer/* instead of relying on the model name
356+
# We can expose an arg at the top level --tokenizer_path to plumb that through.
357+
# This is more stable than relying on the current NeMo-RL get_tokenizer() which can
358+
# change from release to release.
359+
simple_policy_config["model_name"],
360+
)
361+
362+
offline_converted_model = AutoModelForCausalLM.from_pretrained(
363+
offline_converted_model_path
364+
)
365+
366+
online_converted_model = AutoModelForCausalLM.from_pretrained(
367+
os.path.join(tmp_dir, "test_hf_and_dcp-hf")
368+
)
369+
original_model = AutoModelForCausalLM.from_pretrained(
370+
simple_policy_config["model_name"]
371+
)
372+
373+
## make sure both conversions result in the same state dict
374+
check_dict_equality(
375+
online_converted_model.state_dict(), offline_converted_model.state_dict()
376+
)
377+
# Ensure the offline one is different from the original
378+
assert_recursive_dict_different(
379+
offline_converted_model.state_dict(), original_model.state_dict()
380+
)

0 commit comments

Comments
 (0)