
Commit a7b2e2b

KiddoZhu, parthchadha, and SahilJain314 authored and committed

fix: change format messages to out of place (#77)

Signed-off-by: KiddoZhu <zhaochengz@nvidia.com>
Signed-off-by: Sahil Jain <sahilj@nvidia.com>
Co-authored-by: Parth Chadha <pchadha@nvidia.com>
Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
1 parent e27b5fd commit a7b2e2b

2 files changed: 10 additions & 7 deletions

File tree

nemo_reinforcer/algorithms/loss_functions.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -58,7 +58,7 @@ class ClippedPGLossFn(LossFunction):
     - r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio
     - A_t is the advantage estimate
     - ε is the clip parameter (ratio_eps)
-    - As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476),
+    - As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476),
      we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.)
     - ratio_eps_min: minimum value for the clip parameter
     - ratio_eps_max: maximum value for the clip parameter
```
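The docstring above describes the clipped policy-gradient objective with separate lower and upper clip bounds, as DAPO allows. A minimal, dependency-free sketch of that objective (a hypothetical simplified version, not the actual `ClippedPGLossFn` from the repository) might look like:

```python
import math

def clipped_pg_loss(log_probs, old_log_probs, advantages,
                    ratio_eps_min=0.2, ratio_eps_max=0.2):
    """Clipped PG loss with distinct lower/upper clip bounds (DAPO-style).

    Setting ratio_eps_min == ratio_eps_max recovers standard PPO/GRPO clipping.
    """
    total = 0.0
    for lp, olp, adv in zip(log_probs, old_log_probs, advantages):
        ratio = math.exp(lp - olp)  # r_t(θ) = π_θ / π_θ_old
        # Asymmetric clip window [1 - eps_min, 1 + eps_max].
        clipped = min(max(ratio, 1.0 - ratio_eps_min), 1.0 + ratio_eps_max)
        # Pessimistic (min) surrogate, negated so it is a loss to minimize.
        total += -min(ratio * adv, clipped * adv)
    return total / len(advantages)
```

With identical old and new log-probs the ratio is 1, so the loss reduces to the negative mean advantage; with a large ratio and a positive advantage, the clipped surrogate caps the contribution at `1 + ratio_eps_max` times the advantage.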

nemo_reinforcer/data/llm_message_utils.py

Lines changed: 9 additions & 6 deletions
```diff
@@ -353,14 +353,13 @@ def get_formatted_message_log(
     Returns:
         The message log with updated 'token_ids' and 'content' fields.
     """
-    cu_message = []
+    new_message_log = []
     prev_formatted_message = ""
     template = task_data_spec.custom_template
 
     for i, message in enumerate(message_log):
-        cu_message.append(message.copy())
         formatted_message = tokenizer.apply_chat_template(
-            cu_message,
+            message_log[: i + 1],
             chat_template=template,
             add_generation_prompt=False,
             tokenize=False,
@@ -383,13 +382,17 @@ def get_formatted_message_log(
             message_chunk = message_chunk.rstrip("\n")
         if not message_chunk.endswith(tokenizer.eos_token):
             message_chunk += tokenizer.eos_token
-        message["token_ids"] = tokenizer(
+
+        new_message = message.copy()
+        new_message["token_ids"] = tokenizer(
             message_chunk, return_tensors="pt", add_special_tokens=False
         )["input_ids"][0]
-        message["content"] = message_chunk
+        new_message["content"] = message_chunk
+        new_message_log.append(new_message)
+
         prev_formatted_message = formatted_message
 
-    return message_log
+    return new_message_log
 
 
 def remap_dataset_keys(
```
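The core of this commit is making the formatting out-of-place: instead of writing `token_ids` and `content` back into the caller's message dicts, each message is copied before mutation and collected into a fresh list. A minimal, self-contained sketch of that pattern (a hypothetical simplified stand-in for `get_formatted_message_log`, without the tokenizer or chat template) is:

```python
def format_message_log(message_log):
    """Return a new list of formatted messages, leaving the input untouched."""
    new_message_log = []
    for i, message in enumerate(message_log):
        # Format against the conversation prefix so far, mirroring
        # message_log[: i + 1] in the real function.
        formatted = " ".join(m["content"] for m in message_log[: i + 1])

        # Copy before mutating -- the heart of the fix: the old code wrote
        # 'content' (and 'token_ids') back into the caller's dicts.
        new_message = message.copy()
        new_message["content"] = formatted
        new_message_log.append(new_message)
    return new_message_log

log = [{"role": "user", "content": "hi"},
       {"role": "assistant", "content": "hello"}]
out = format_message_log(log)
assert log[1]["content"] == "hello"     # input log is unchanged
assert out[1]["content"] == "hi hello"  # output holds the formatted text
```

Because `dict.copy()` is shallow, only the keys the function assigns are safe to replace; mutating a nested value in the copy would still be visible through the original.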
