Conversation

Force-pushed from `3dbcab5` to `ecd1918`
```diff
 # Using varlen_mamba for variable length sequence support
 RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
-RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
+RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba"
```
Ignoring varlen mamba for now
Ok, but it's mission critical
```diff
 _ACTIVATION_FN_MAP = {
-    ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
+    ActivationType.gelu: torch.nn.functional.gelu,
```
We can't change this
Can we then add another one instead?
I think that's possible, but is it absolutely needed? The two are usually safe to swap, so if it's just to convert HF models I'd rather just convert everything to tanh gelu.
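For reference, the two variants are numerically very close; a quick stdlib-only sketch (no torch needed) comparing exact GELU with the tanh approximation that `torch.nn.functional.gelu(..., approximate="tanh")` implements:

```python
import math

def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # Tanh approximation, matching torch's approximate="tanh" formula.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

# Maximum absolute discrepancy over [-5, 5].
max_diff = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in (i / 100 for i in range(-500, 501)))
print(max_diff)  # well below 1e-2, small relative to typical activations
```

This is why the two are "usually safe to swap": the discrepancy is tiny, though a converted checkpoint is not bit-identical.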
```diff
@@ -0,0 +1,55 @@
+import typing
```
This is just an MLP, let's make it one.
Only useful if this gives better speed because of fused kernels. Otherwise not worth it because we don't need to make this thing overly flexible.
Main benefit is we get the implementation for free, no new code needed. Flexibility is just a side effect.
```python
@config_class()
class ImageNormalizationConfig(Config):
```
Belongs to dataset preprocessing
```diff
@@ -0,0 +1,281 @@
+import math
```
Moving most of this to dataset preprocessing
```python
use_loss_masking_spans=self._parameters.use_loss_masking_spans,
)
token_ids.append(sample.token_ids)
start_pos = 0
```
All this mess is there so we know how many tokens the images will take, because preprocessing is done in the model. Moving to dataset preprocessing (before sampling) to simplify.
As discussed.
Watch your attitude please, though.
Sorry, commented for self-reference here
```python
batch_data = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data)
batch_dim = TensorDim(BlockDimNames.batch, micro_batch_size * batch_data.size, batch_data)
if self._config.vision_encoder.enabled:
```
Unnecessary complexity, not needed once we move preprocessing
```python
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.layers.attention.config import AttentionKwargs
from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor
```
Too many conflicts with recent changes, redoing entirely. Moving to a separate model.
```python
kv_channels = "vision_kv_channels"


class VisionEncoderKwargs:
```
Unnecessary complexity. We can get those from the configs
```diff
 Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}")
 self._version = struct.unpack("<Q", stream.read(8))[0]
-assert self._version in [1, 2, 3], f"Unsupported version for gpt_memmap dataset: {self._version}."
+assert self._version in [1, 2, 3, 4], f"Unsupported version for gpt_memmap dataset: {self._version}."
```
This is not sustainable. Switching to a json header.
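A length-prefixed JSON header is one way to get that extensibility; here is a minimal stdlib-only sketch (the magic string and field names are hypothetical, not the actual gpt_memmap format):

```python
import io
import json
import struct

MAGIC = b"FLLMIDX\x00"  # hypothetical magic string identifying the index file

def write_header(stream, header: dict) -> None:
    blob = json.dumps(header).encode("utf-8")
    stream.write(MAGIC)
    stream.write(struct.pack("<Q", len(blob)))  # little-endian uint64 length prefix
    stream.write(blob)

def read_header(stream) -> dict:
    assert stream.read(len(MAGIC)) == MAGIC, "not an index file"
    (size,) = struct.unpack("<Q", stream.read(8))
    return json.loads(stream.read(size).decode("utf-8"))

buf = io.BytesIO()
write_header(buf, {"version": 4, "has_images": True, "token_dtype": "int32"})
buf.seek(0)
header = read_header(buf)
print(header["version"])  # 4
```

New fields become new dictionary keys, so readers no longer need a hard-coded list of supported versions for every format change.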
```python
offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
    )
)
offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes
```
This is inefficient. We can just store the image count cumsums instead to know which images belong to each sample. (Same for spans)
HDF5 could be an option...
That's something to consider, but wouldn't be that beneficial right now because nearly all the complexity is in the data processing and preparation (which we need either way) rather than the actual file content.
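The cumsum layout suggested above can be sketched with stdlib Python (the sample data is made up): store one flat array of images plus per-sample counts, and slice by cumulative offsets.

```python
from itertools import accumulate

# Per-sample image counts; sample i owns flat_images[cumsum[i]:cumsum[i + 1]].
image_counts = [2, 0, 3]
cumsum = [0, *accumulate(image_counts)]  # [0, 2, 2, 5]
flat_images = ["img0", "img1", "img2", "img3", "img4"]

def images_for_sample(i: int) -> list:
    # O(1) slice lookup instead of per-sample offset bookkeeping.
    return flat_images[cumsum[i]:cumsum[i + 1]]

print(images_for_sample(2))  # ['img2', 'img3', 'img4']
```

The same trick works for loss-masking spans: one flat span array plus a per-sample span-count cumsum.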
```diff
 loss = per_sample_loss.mean()
 if target_format != TargetFormat.labels and group is not None:
-    all_reduce(loss, op=ReduceOp.MEAN, group=group)
+    all_reduce(loss, op=ReduceOp.SUM, group=group)
```
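For context on the `MEAN` → `SUM` change, a toy sketch of the distinction (assumption: each rank holds a partial term of one loss value, as in a tensor-parallel reduction, rather than a full replicated copy):

```python
# Two "ranks", each holding a partial contribution to the same loss value.
partial_losses = [0.3, 0.5]
true_loss = 0.8  # what a single-rank run would compute

# ReduceOp.MEAN averages the partial terms -- wrong for partial contributions.
mean_reduce = sum(partial_losses) / len(partial_losses)  # 0.4
# ReduceOp.SUM adds them back together -- recovers the full loss.
sum_reduce = sum(partial_losses)  # 0.8

print(sum_reduce == true_loss, mean_reduce == true_loss)
```

If the values were instead full replicated copies on every rank, MEAN would be the correct choice; the right op depends on what each rank holds.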
```diff
 from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
 from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward
-from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig
+from fast_llm.functional.config import (
```
Ignoring changes to model head and reverse kl for now
```diff
 input_, output = grad_context
 output.backward(output_grad)
-return input_.grad
+return input_.grad if input_.grad is not None else torch.zeros_like(input_)
```
Bad idea. This adds overhead to the first layer for all models
Solved a problem though. Was the most pragmatic thing to do for Soham
```diff
 ):
-    scaled_target = target / teacher_softmax_temperature
+    scaled_target = torch.clamp(target, min=-50, max=50)
```
```diff
@@ -0,0 +1,183 @@
+import typing
```
Much simpler solution that can go directly in LanguageModelEmbedding and generalizes to other multimodal models:

- Move token ids to kwargs, replace the input by the image or other embeddings. Same as here, but do it in the base class too. Use a placeholder or `None` for the LLM's non-existent input.
- Use a pre-built map (tensors) from `input_` (image embeddings) to LM embeddings, and copy image/multimodal embeddings with a one-liner.
These are all inputs; it's a bit limiting to have to declare one the "official" input and make the others kwargs. Btw, the generic case is multiple embeddings (images, audio, etc.) and token-id maps, not just one.
Having a single input is a restriction from the engine, so not much we can do ATM. The new version (#369) is generic in the sense that it copies any kinds of embeddings from previous layers into token ids regardless of their source. It's still not combining inputs if they come from multiple previous layers, but that will be relatively straightforward to do when we need it.
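The pre-built-map idea can be sketched with plain Python lists (the placeholder id and all names here are hypothetical; with tensors, the copy loop collapses to a single indexed assignment like `embeddings[image_positions] = image_embeddings`):

```python
IMAGE_TOKEN = -1  # hypothetical placeholder id marking image-patch positions

token_ids = [5, IMAGE_TOKEN, IMAGE_TOKEN, 7, IMAGE_TOKEN]
image_embeddings = ["patch0", "patch1", "patch2"]  # stand-ins for embedding vectors

# Regular tokens get their token embedding; placeholders start empty.
lm_embeddings = [f"tok{t}" if t != IMAGE_TOKEN else None for t in token_ids]

# Pre-built map: positions in the LM sequence that receive image embeddings.
image_positions = [i for i, t in enumerate(token_ids) if t == IMAGE_TOKEN]

# The "one-liner" copy (elementwise here; a single scatter with tensors).
for pos, emb in zip(image_positions, image_embeddings):
    lm_embeddings[pos] = emb

print(lm_embeddings)  # ['tok5', 'patch0', 'patch1', 'tok7', 'patch2']
```

Because the map only records positions, it works the same for audio or any other modality whose embeddings arrive from a previous layer.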
```python
# Move to the next image in the input tensor
image_embedding_offset += num_patches

if self._use_absolute_position_embeddings:
```
Position embeddings should be masked?
```python
if self._use_absolute_position_embeddings:
    position_ids = split(position_ids, group=group, dim=0)
# mask padded tokens
token_mask = tokens >= 0
```
Is there a one-to-one correspondence between masked tokens and those replaced by image embeddings? If so this seems a bit redundant, and there are better solutions...
In general no, masked tokens could be padding, too
Do we need explicit masking for those though? Padding tokens are already masked for attention/ssm (through varlen) and loss (loss mask), so they don't contribute either way...
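What `token_mask = tokens >= 0` does, sketched with plain lists (padding convention assumed: negative token ids mark padding; with tensors this is a broadcasted multiply by the mask):

```python
PAD = -2  # hypothetical: negative ids mark padding
token_ids = [4, 9, PAD, PAD]
embeddings = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]

# Zero out embeddings at padded positions; real tokens pass through unchanged.
masked = [vec if tok >= 0 else [0.0] * len(vec) for tok, vec in zip(token_ids, embeddings)]
print(masked)  # [[1.0, 1.0], [2.0, 2.0], [0.0, 0.0], [0.0, 0.0]]
```

The question in the thread stands: if varlen attention and the loss mask already ignore those positions, zeroing the embeddings as well may be redundant work.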
```python
image_embedding_offset += num_patches
if image_embedding_offset > patch_end_offset:
    break
embeddings = reduce_forward(embeddings, group)
```
I think this is very wrong. The image embeddings are not vocab-parallel, so this incorrectly multiplies the image embeddings by TP.
Actually this might be correct in a really complicated way that mixes the gathering of the sequence-parallel vision encoder outputs with their mapping to the lm embeddings. Not sure about non-sequence-parallel. Either way, this needs simplification.
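A toy illustration of the concern: with vocab-parallel token embeddings, exactly one rank owns each vocabulary row and the others contribute zero, so an all-reduce sum recovers the right value; a replicated (non-vocab-parallel) image embedding contributes its full value from every rank and gets multiplied by the group size.

```python
tp = 4  # assumed tensor-parallel group size

# Vocab-parallel token embedding: only the owning rank has a nonzero contribution.
token_contrib_per_rank = [0.0, 0.0, 1.5, 0.0]
token_after_reduce = sum(token_contrib_per_rank)
print(token_after_reduce)  # 1.5 -- all-reduce(SUM) is correct

# Replicated image embedding: every rank contributes the full value.
image_contrib_per_rank = [2.0] * tp
image_after_reduce = sum(image_contrib_per_rank)
print(image_after_reduce)  # 8.0 -- the embedding is scaled by tp, not preserved
```

If the image embeddings are instead sequence-parallel, a gather rather than a sum is what preserves them, which is consistent with the "correct in a complicated way" reading above.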
```python
patch_position_ids = torch.cat(patch_position_ids)
kwargs[VisionEncoderKwargs.image_patches] = patches
kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids
kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs(
```
✨ Description
Workspace for dealing with the merge. Not intended to work yet.