
Language model sample #378

Merged: jlamypoirier merged 16 commits into main from jlp/lm_sample on Nov 24, 2025
Conversation

@jlamypoirier (Collaborator) commented Oct 16, 2025

✨ Description

  • Replace GPTSample and GPTBatch with LanguageModelSample and LanguageModelBatch, encapsulating much of the related functionality and simplifying much of the code that uses them (e.g. SampledIndexedDataset, GPTBaseModel.preprocess_batch, gpt_data_collate_fn).
  • Make SampledIndexedDataset agnostic of the sample type.
  • Change spans to use a standard range format, i.e. (first, last + 1) instead of (first, last). (Spans are still stored in (first, last) format in the binary file, this will be changed for the new binary format.)
  • Redo much of the code for preference spans. SampledIndexedDataset was using an entirely separate code path for preference spans that avoided multi-document samples (for historical reasons, I think); I dropped it and made it use the common code path. This does mean a change in behavior (e.g. multi-document samples), but that's an improvement. I have some doubts about the DPO loss function, though: I suspect the log softmax needs to be calculated separately for each document.
  • Datasets now always provide sequence lengths. Move cross_document_attention from BatchConfig to AttentionConfig, so the attention layer itself decides whether to use varlen. Replace use_flash_attention with a more generic implementation enum (see discussion in Base model interface review #370). Add a separate LanguageModelEmbeddingsConfig.cross_document_position_embeddings, since absolute position embeddings may also use the sequence lengths.

@jlamypoirier jlamypoirier mentioned this pull request Oct 16, 2025
@jlamypoirier jlamypoirier marked this pull request as ready for review October 17, 2025 03:20
@tscholak (Collaborator) left a comment
This looks OK to me. I think we should merge.

For others looking at this, the most important changes are in:

A few things we should definitely make sure of, ideally via tests (if we haven't already):

  • span cropping, offsetting, and truncation arithmetic is correct
  • loss masking masks correctly
  • DPO works correctly
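As a sketch of the loss-masking check above: half-open spans mark token positions whose loss should be ignored, and a test can assert the mask directly. All names here are hypothetical, not the Fast-LLM test suite.

```python
def build_loss_mask(length: int, masked_spans: list[tuple[int, int]]) -> list[bool]:
    """Build a per-position mask; True means the position contributes to the loss.

    Spans use the half-open (begin, end) convention described in the PR.
    """
    mask = [True] * length
    for begin, end in masked_spans:
        for i in range(begin, min(end, length)):
            mask[i] = False
    return mask


# Mask the first two and positions 5-6 of an 8-token sample.
mask = build_loss_mask(8, [(0, 2), (5, 7)])
assert mask == [False, False, True, True, True, False, False, True]
assert sum(mask) == 4  # positions kept in the loss
```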

Base automatically changed from jlp/dataset_interface to main November 24, 2025 15:44
@jlamypoirier jlamypoirier merged commit 261717c into main Nov 24, 2025
2 checks passed
@jlamypoirier jlamypoirier deleted the jlp/lm_sample branch November 24, 2025 18:19
