fast_llm/functional/cross_entropy.py
per_token_loss = torch.nn.functional.cross_entropy(
    logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
)
loss = (per_token_loss * loss_mask).sum() / loss_mask.sum()
This can result in NaNs if loss_mask.sum() is 0, which can actually happen in practice: in reasoning SFT, prompts can be very long, or when we use TP and split across the sequence-length dimension.
So maybe better to check something like:
mask_sum = loss_mask.sum()
if mask_sum > 0:  # can be 0 for inputs containing only prompts
    loss = (per_token_loss * loss_mask).sum() / mask_sum
else:
    loss = (per_token_loss * 0.0).mean()  # preserve grads
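To make the proposed guard concrete, here is a minimal standalone sketch of a masked cross-entropy with the zero-mask fallback. The function name `masked_cross_entropy` and the free-standing signature are illustrative, not the actual code in `fast_llm/functional/cross_entropy.py`:

```python
import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, target, loss_mask, logits_scale_factor=1.0):
    # Hypothetical standalone version of the patched loss; names are illustrative.
    if logits_scale_factor != 1.0:
        logits = logits * logits_scale_factor
    per_token_loss = F.cross_entropy(logits, target, reduction="none")
    mask_sum = loss_mask.sum()
    if mask_sum > 0:
        loss = (per_token_loss * loss_mask).sum() / mask_sum
    else:
        # All tokens masked (e.g. a prompt-only micro-sequence): return a
        # zero loss that still participates in autograd, so gradients for
        # the logits stay defined (all-zero) instead of becoming NaN.
        loss = (per_token_loss * 0.0).mean()
    return loss
```

The `else` branch multiplies by 0.0 rather than returning a detached constant, so `loss.backward()` still propagates (zero) gradients through the logits.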
As discussed with @oleksost, to finish the fix we'd also need to properly reduce the loss across micro-sequences, taking into account the sum of the loss mask in each one. Now, on second thought: the question is whether we want an average of the loss over samples, or over tokens.
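A token-level average across micro-sequences would weight each partial loss by its mask sum rather than averaging the partial means directly. This is a sketch of that reduction under assumed inputs (a list of per-micro-sequence mean losses and their mask sums); the helper name is hypothetical:

```python
def reduce_micro_sequence_losses(losses, mask_sums):
    # losses[i]: mean masked loss of micro-sequence i
    # mask_sums[i]: number of unmasked tokens in micro-sequence i
    total_tokens = sum(mask_sums)
    if total_tokens == 0:
        # Every micro-sequence was fully masked; keep the zero-loss convention.
        return sum(losses) * 0.0
    # Weight each partial mean by its token count to recover the global
    # token average; an unweighted mean of the partial losses would instead
    # give a per-micro-sequence (roughly per-sample) average.
    return sum(l * m for l, m in zip(losses, mask_sums)) / total_tokens
```

The choice between the weighted and unweighted reduction is exactly the "average over tokens vs. over samples" question raised above.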
✨ Description
triton_cross_entropy_from_distribution_forward_backward_kernel

Closes #344
🔍 Type of change
Select all that apply:
📝 Changes
List the key changes introduced in this PR:
✅ Checklist
Make sure the following tasks are completed before submitting the PR:
General
Dependencies and Configuration
Testing
Performance Impact
📊 Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable:
🗒️ Additional Notes
Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.