Bug
KeyRerotationPress fails on multi-GPU setups when the model is loaded with device_map="auto".
When the model is split across multiple CUDA devices, tensors used during key re-rotation can be on different devices. In particular, selected_positions / delta_pos and module.rotary_emb.inv_freq may end up on different GPUs, causing _rerotate_cos_sin() to fail.
This affects FinchPress and potentially any other press that uses KeyRerotationPress.rerotate_keys().
Error:

```
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
```
Relevant stack trace:

```
  File ".../kvpress/presses/finch_press.py", line 113, in compress
    keys = KeyRerotationPress.rerotate_keys(module, indices, keys)
  File ".../kvpress/presses/key_rerotation_press.py", line 122, in rerotate_keys
    new_cos, new_sin = KeyRerotationPress._rerotate_cos_sin(
        keys, module.rotary_emb.inv_freq, indices
    )
  File ".../kvpress/presses/key_rerotation_press.py", line 90, in _rerotate_cos_sin
    freqs = delta_pos.float() * inv_freq.float()
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
```
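One possible fix is to move `inv_freq` (and the position tensors) onto a common device before the multiplication that fails at line 90. The sketch below is a standalone illustration, not the kvpress code. The function name `rerotate_cos_sin_device_safe` is hypothetical, and it assumes `keys` has shape `(bsz, num_kv_heads, n_kept, head_dim)`, `inv_freq` has shape `(head_dim // 2,)`, and `selected_positions` has shape `(bsz, num_kv_heads, n_kept)` holding the original positions of the kept keys. It computes the cos/sin needed to move each kept key from its original position to its new compacted position, after aligning every input with `keys.device`.

```python
import torch


def rerotate_cos_sin_device_safe(keys, inv_freq, selected_positions):
    """Hypothetical device-safe cos/sin computation for key re-rotation.

    Assumed shapes (not taken from kvpress):
      keys:               (bsz, num_kv_heads, n_kept, head_dim)
      inv_freq:           (head_dim // 2,)
      selected_positions: (bsz, num_kv_heads, n_kept), original positions
    """
    # With device_map="auto", inv_freq can sit on a different GPU than the
    # keys, so move every input to the keys' device before any arithmetic.
    device = keys.device
    inv_freq = inv_freq.to(device=device, dtype=torch.float32)
    selected_positions = selected_positions.to(device).float()

    n_kept = selected_positions.shape[-1]
    # New (compacted) positions 0..n_kept-1 for the kept keys.
    new_pos = torch.arange(n_kept, device=device, dtype=torch.float32)
    # Rotation offset from each key's original position to its new one.
    delta_pos = new_pos.view(1, 1, n_kept) - selected_positions

    # (bsz, heads, n_kept, 1) * (head_dim // 2,) -> (bsz, heads, n_kept, head_dim // 2)
    freqs = delta_pos.unsqueeze(-1) * inv_freq
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos().to(keys.dtype), emb.sin().to(keys.dtype)


if __name__ == "__main__":
    keys = torch.randn(1, 2, 4, 8)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
    positions = torch.tensor([0, 3, 5, 9]).expand(1, 2, 4)
    cos, sin = rerotate_cos_sin_device_safe(keys, inv_freq, positions)
    print(cos.shape, cos.device)
```

Inside `_rerotate_cos_sin` itself, the equivalent change would presumably be a single `inv_freq = inv_freq.to(delta_pos.device)` before line 90. Aligning to the keys' device, as in the sketch, is one choice among several; it has the advantage of producing `cos`/`sin` next to the tensor they are applied to.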
To Reproduce
Run the following on a multi-GPU machine where device_map="auto" splits the model across devices:
```python
from transformers import pipeline
from kvpress import FinchPress
import torch

device = "auto"
model = "meta-llama/Llama-3.1-8B-Instruct"
model_kwargs = {"attn_implementation": "flash_attention_2", "torch_dtype": torch.bfloat16}

pipe = pipeline(
    "kv-press-text-generation",
    model=model,
    device_map=device,
    model_kwargs=model_kwargs,
    trust_remote_code=True,
)

context = "A very long text you want to compress once and for all"
question = "\nA question about the compressed context"

press = FinchPress(
    compression_ratio=0.5,
    normalize_scores=True,
)
press.update_model_and_tokenizer(pipe.model, pipe.tokenizer)

augmented_context = context + press.delimiter_token + question

result = pipe(
    augmented_context,
    question="",
    press=press,
    max_new_tokens=128,
)
answer = result["answer"]
print(answer)
```
Expected behavior: generation runs successfully.
Actual behavior: key re-rotation fails with a CUDA device mismatch.
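The failure only appears when `device_map="auto"` actually places layers on more than one GPU. One way to confirm that on a given machine is to inspect `pipe.model.hf_device_map`, which transformers sets when a model is loaded with a `device_map` (values are GPU indices such as `0`, `1`, or strings such as `"cpu"`). The helpers below, `devices_in_map` and `is_split_across_devices`, are hypothetical conveniences for that check, not part of kvpress or transformers.

```python
def devices_in_map(hf_device_map):
    """Return the sorted list of distinct devices used by a device map.

    hf_device_map maps module names to devices, e.g. {"model.layers.0": 0}.
    Devices are stringified so ints (GPU indices) and strings ("cpu") sort together.
    """
    return sorted({str(device) for device in hf_device_map.values()})


def is_split_across_devices(hf_device_map):
    """True if the model's modules live on more than one device."""
    return len(devices_in_map(hf_device_map)) > 1


if __name__ == "__main__":
    example_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.31": 1, "lm_head": 1}
    print(devices_in_map(example_map), is_split_across_devices(example_map))
```

With the reproduction script above, `print(devices_in_map(pipe.model.hf_device_map))` listing more than one GPU indicates the setup in which the mismatch occurs.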
Repository version
fa42106