
KeyRerotationPress fails on multi-GPU setups #240

@giulio98


Bug

KeyRerotationPress fails on multi-GPU setups when the model is loaded with device_map="auto".

When the model is split across multiple CUDA devices, the tensors involved in key re-rotation can live on different devices. In particular, the position tensors (selected_positions / delta_pos) and module.rotary_emb.inv_freq may end up on different GPUs, so the element-wise product in _rerotate_cos_sin() fails.

This affects FinchPress and potentially any other press that uses KeyRerotationPress.rerotate_keys().
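One possible fix, sketched under stated assumptions: move inv_freq onto the device of the position offsets before the delta_pos * inv_freq product in _rerotate_cos_sin(). The function rerotation_cos_sin below is a hypothetical stand-in written for illustration, not kvpress's actual implementation; the (..., seq_len) shape of delta_pos and the duplicated-frequency (rotate_half) layout are assumptions based on the standard Llama RoPE convention.

```python
import torch


def rerotation_cos_sin(delta_pos: torch.Tensor, inv_freq: torch.Tensor, dtype: torch.dtype):
    """Hypothetical stand-in for KeyRerotationPress._rerotate_cos_sin (illustration only).

    delta_pos: position offsets, assumed shape (..., seq_len)
    inv_freq:  RoPE inverse frequencies, shape (head_dim // 2,)
    """
    # The fix: under device_map="auto", inv_freq may sit on a different GPU than
    # the position offsets, so align it before the element-wise product.
    inv_freq = inv_freq.to(delta_pos.device)
    # Outer product of offsets and frequencies -> (..., seq_len, head_dim // 2)
    freqs = delta_pos.float().unsqueeze(-1) * inv_freq.float()
    # Duplicate to the full head_dim, as in the usual rotate_half RoPE layout
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos().to(dtype), emb.sin().to(dtype)


# Example on CPU (a single device, so it only checks shapes and device placement)
head_dim = 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
delta_pos = torch.tensor([[0, -1, -3]])
cos, sin = rerotation_cos_sin(delta_pos, inv_freq, torch.bfloat16)
print(cos.shape, cos.device)
```

Whether to align to the device of delta_pos or of keys is a design choice; either works as long as every operand of the product ends up on one device.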

Error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

Relevant stack trace:

File ".../kvpress/presses/finch_press.py", line 113, in compress
    keys = KeyRerotationPress.rerotate_keys(module, indices, keys)

File ".../kvpress/presses/key_rerotation_press.py", line 122, in rerotate_keys
    new_cos, new_sin = KeyRerotationPress._rerotate_cos_sin(
        keys, module.rotary_emb.inv_freq, indices
    )

File ".../kvpress/presses/key_rerotation_press.py", line 90, in _rerotate_cos_sin
    freqs = delta_pos.float() * inv_freq.float()

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

To Reproduce

Run the following on a multi-GPU machine where device_map="auto" splits the model across devices:

from transformers import pipeline
from kvpress import FinchPress
import torch

device = "auto"  # let accelerate shard the model across all visible GPUs
model = "meta-llama/Llama-3.1-8B-Instruct"
model_kwargs = {"attn_implementation": "flash_attention_2", "torch_dtype": torch.bfloat16}

pipe = pipeline(
    "kv-press-text-generation",
    model=model,
    device_map=device,
    model_kwargs=model_kwargs,
    trust_remote_code=True,
)

context = "A very long text you want to compress once and for all"
question = "\nA question about the compressed context"

press = FinchPress(
    compression_ratio=0.5,
    normalize_scores=True,
)

press.update_model_and_tokenizer(pipe.model, pipe.tokenizer)

augmented_context = context + press.delimiter_token + question

result = pipe(
    augmented_context,
    question="",
    press=press,
    max_new_tokens=128,
)

answer = result["answer"]
print(answer)
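The bug only triggers when device_map="auto" actually places modules on more than one GPU. A quick check, sketched here, inspects the hf_device_map attribute that transformers sets on models loaded with a device_map. The helper spans_multiple_devices is hypothetical, and treating integer entries as CUDA indices assumes a CUDA machine.

```python
from collections.abc import Mapping
from typing import Any


def spans_multiple_devices(device_map: Mapping[str, Any]) -> bool:
    """Hypothetical helper: True if a device map places modules on more than one device."""
    # Normalise entries: integers are assumed to be CUDA indices; strings such as
    # "cpu" or "cuda:1" and torch.device objects are compared by their string form.
    devices = {f"cuda:{d}" if isinstance(d, int) else str(d) for d in device_map.values()}
    return len(devices) > 1


# Usage with the reproduction above (not run here):
# print(spans_multiple_devices(pipe.model.hf_device_map))
```

If this prints False, the model fits on a single device and the mismatch above should not occur.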

Expected behavior: generation runs successfully.

Actual behavior: key re-rotation fails with a CUDA device mismatch.

Repository version

fa42106

Labels

bug
