Merged
13 changes: 5 additions & 8 deletions README.md
@@ -130,6 +130,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
- `NonCausalAttnPress` ([source](kvpress/presses/non_causal_attention_press.py), [paper](https://arxiv.org/abs/2507.08143)): evicts tokens based on non-causal chunked attention scores.
- `LeverageScorePress` ([source](kvpress/presses/leverage_press.py), [paper](https://arxiv.org/abs/2507.08143)): evicts tokens based on approximate statistical leverage (i.e., we preserve outliers in the key space).
- `CompactorPress` ([source](kvpress/presses/compactor_press.py), [paper](https://arxiv.org/abs/2507.08143)): blends `NonCausalAttnPress` and `LeverageScorePress` based on the compression_ratio.
+- `CURPress` ([source](kvpress/presses/cur_press.py), [paper](https://arxiv.org/abs/2509.15038)): prune keys and values based on the CUR decomposition using approximate leverage scores.

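The `ScorerPress`-based presses listed above share one pattern: assign a score to every cached token, then evict the lowest-scoring ones. The sketch below illustrates that pattern in plain Python. `keep_top_scoring` and its round-down rule are assumptions made for illustration; this is not the kvpress implementation.

```python
def keep_top_scoring(scores, compression_ratio):
    """Return the indices of the tokens to keep, in their original order.

    Hypothetical helper: evicts the lowest-scoring fraction of tokens.
    """
    if not 0.0 <= compression_ratio < 1.0:
        raise ValueError("compression_ratio must be in [0, 1)")
    # Assumption: the number of kept tokens is rounded down.
    n_kept = int(len(scores) * (1.0 - compression_ratio))
    # Rank token indices from highest to lowest score.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    # Restore positional order so the kept tokens stay in sequence.
    return sorted(ranked[:n_kept])


kept = keep_top_scoring([0.1, 0.9, 0.4, 0.7], compression_ratio=0.5)
```

Under these assumptions, a press would differ from the others only in how it computes `scores`; the selection step stays the same.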
Some presses rely on a different logic:
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/pdf/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
@@ -147,28 +148,24 @@ Finally we provide wrapper presses that can be combined with other presses:
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield more uniform compression across long sequences.
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
- `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
-- `DeocdingPress` ([source](kvpress/presses/decoding_press.py)): Allows for compression during decoding, see decoding section in this README.
+- `DecodingPress` ([source](kvpress/presses/decoding_press.py)): Allows for compression during decoding, see decoding section in this README.
- `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): Allows to compress both during prefilling and during decoding.
-- `CURPress` ([source](kvpress/presses/cur_press.py), [paper](https://arxiv.org/abs/2509.15038)): prune keys and
-values based on the CUR decomposition using approximate leverage scores.

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)


## Evaluation
-We provide a simple CLI to evaluate the performance of different presses on several long-context datasets.

-- Accuracy: Test your method on popular benchmarks directly using our CLI. For a broader comparison, check out our public [Hugging Face Leaderboard](https://huggingface.co/spaces/nvidia/kvpress-leaderboard)
-, where you can see how various methods stack up against each other.
+We provide a simple CLI to evaluate the performance of different presses on several long-context datasets.

+- Accuracy: Test your method on popular benchmarks directly using our CLI.
- Speed and Memory: The [speed_and_memory](notebooks/speed_and_memory.ipynb) notebook can help you measure peak memory usage and total time gain.

Please refer to the [evaluation](evaluation/README.md) directory in this repo for more details and results.

Below we report the average performance on the RULER dataset with 4k context length for different presses, from our [![Hugging Face Leaderboard](https://img.shields.io/badge/🤗%20HuggingFace-Leaderboard-orange)](https://huggingface.co/spaces/nvidia/kvpress-leaderboard)

<p>
-<img src="evaluation/assets/leaderboard_plot_score.png" alt="Leaderboard">
+<img src="leaderboard_plot_score.png" alt="Leaderboard">
</p>


20 changes: 18 additions & 2 deletions evaluation/README.md
@@ -46,10 +46,26 @@ At the moment, we support the following standard popular benchmarks:
- [Zero Scrolls](benchmarks/zero_scrolls/README.md) ([hf link](https://huggingface.co/datasets/simonjegou/zero_scrolls))
- [Infinitebench](benchmarks/infinite_bench/README.md) ([hf link](https://huggingface.co/datasets/MaxJeblick/InfiniteBench))
- [longbench](benchmarks/longbench/README.md)([hf link](https://huggingface.co/datasets/Xnhyacinth/LongBench))
-- [longbench-v2](benchmarks/longbenchv2/README.md)([hf link](https://huggingface.co/datasets/Xnhyacinth/LongBench-v2))
+- [longbench-v2](benchmarks/longbenchv2/README.md)([hf link](https://huggingface.co/datasets/simonjegou/LongBench-v2))
- [Needle in a Haystack](benchmarks/needle_in_haystack/README.md) ([Paul Graham's essays](https://huggingface.co/datasets/alessiodevoto/paul_graham_essays))

-📚 **For detailed information** about each dataset or implementing custom benchmarks, see the individual README files in the benchmarks directory.
+Each dataset directory is structured as follows:
+
+```bash
+$dataset
+├── README.md
+├── calculate_metrics.py
+├── create_huggingface_dataset.py
+```
+
+Where:
+- `create_huggingface_dataset.py` is a script that generates the Hugging Face dataset from the original dataset. Each dataset is associated with a set of parquet files with the following structure:
+  - `context`: ...
+  - `question`: ...
+  - `answer_prefix`: ...
+  - `answer`: ...
+  - `max_new_tokens`: ...
+- `calculate_metrics.py` is a script that calculates the metrics based on the output of `evaluate.py`

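The column list above can be checked mechanically before running an evaluation. A minimal sketch, assuming each row is available as a plain dictionary: `missing_columns` and the example values are hypothetical, and the meaning of each column is left as described in the README.

```python
REQUIRED_COLUMNS = ("context", "question", "answer_prefix", "answer", "max_new_tokens")


def missing_columns(row):
    """Return the required column names absent from a dataset row (a dict)."""
    return [name for name in REQUIRED_COLUMNS if name not in row]


# Placeholder row; the values are invented for illustration only.
example_row = {
    "context": "<long document>",
    "question": "<question about the document>",
    "answer_prefix": "<prefix>",
    "answer": "<reference answer>",
    "max_new_tokens": 32,
}

# A row that lacks three of the five columns.
incomplete_row = {"context": "<long document>", "question": "<question>"}
```

Calling `missing_columns(incomplete_row)` reports the absent columns in the order they are listed in `REQUIRED_COLUMNS`, which keeps error messages stable across runs.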

### Multi GPU Evaluation
Binary file removed evaluation/assets/infinitebench_kv_retrieval.png
Binary file removed evaluation/assets/infinitebench_longbook_qa_eng.png
Binary file removed evaluation/assets/leaderboard_plot_score.png
Binary file removed evaluation/assets/loogle_longdep_qa.png
Binary file removed evaluation/assets/loogle_shortdep_cloze.png
Binary file removed evaluation/assets/loogle_shortdep_qa.png
Binary file removed evaluation/assets/peak_memory_consumption.png
Binary file removed evaluation/assets/peak_memory_consumption_xkcd.png
Binary file removed evaluation/assets/ruler_4096_average score.png
Binary file removed evaluation/assets/ruler_4096_cwe.png
Binary file removed evaluation/assets/ruler_4096_fwe.png
Binary file removed evaluation/assets/ruler_4096_niah_multikey_1.png
Binary file removed evaluation/assets/ruler_4096_niah_multikey_2.png
Binary file removed evaluation/assets/ruler_4096_niah_multikey_3.png
Binary file removed evaluation/assets/ruler_4096_niah_multiquery.png
Binary file removed evaluation/assets/ruler_4096_niah_multivalue.png
Binary file removed evaluation/assets/ruler_4096_niah_single_1.png
Binary file removed evaluation/assets/ruler_4096_niah_single_2.png
Binary file removed evaluation/assets/ruler_4096_niah_single_3.png
Binary file removed evaluation/assets/ruler_4096_qa_1.png
Binary file removed evaluation/assets/ruler_4096_qa_2.png
Binary file removed evaluation/assets/ruler_4096_vt.png
Binary file removed evaluation/assets/ruler_llama_xkcd.png
102 changes: 0 additions & 102 deletions evaluation/benchmarks/README.md

This file was deleted.

2 changes: 1 addition & 1 deletion evaluation/evaluate_config.yaml
@@ -22,5 +22,5 @@ device: null # Device to use (null = auto-detect, "cuda:0", "cpu", etc.)
# You can add any model kwargs here.
model_kwargs:
  attn_implementation: null
-  torch_dtype: "auto"
+  dtype: "auto"

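This diff renames the `torch_dtype` key to `dtype` in the config, the notebooks, and the `from_pretrained` calls. Configs written with the old key could be migrated with a small helper such as the one sketched below. `migrate_model_kwargs` is a hypothetical name and is not part of kvpress or transformers; the rule that an explicit `dtype` wins over a legacy `torch_dtype` is also an assumption.

```python
def migrate_model_kwargs(model_kwargs):
    """Return a copy of model_kwargs with a legacy `torch_dtype` key renamed to `dtype`."""
    migrated = dict(model_kwargs)  # never mutate the caller's dictionary
    if "torch_dtype" in migrated:
        legacy_value = migrated.pop("torch_dtype")
        # Assumption: an explicitly set `dtype` takes precedence over the legacy key.
        migrated.setdefault("dtype", legacy_value)
    return migrated


old_style = {"attn_implementation": None, "torch_dtype": "auto"}
new_style = migrate_model_kwargs(old_style)
```

The helper copies its input, so the original dictionary still holds the old key after the call.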
2 changes: 1 addition & 1 deletion kvpress/presses/cur_press.py
@@ -8,7 +8,7 @@
import torch
import torch.nn.functional as F

-from kvpress import ScorerPress
+from kvpress.presses.scorer_press import ScorerPress


@dataclass
4 changes: 2 additions & 2 deletions kvpress/presses/expected_attention_press.py
@@ -11,7 +11,7 @@
from transformers.models.llama.modeling_llama import repeat_kv

from kvpress.presses.scorer_press import ScorerPress
-from kvpress.utils import get_query_states
+from kvpress.utils import get_prerope_query_states


@dataclass
@@ -68,7 +68,7 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor):

# Remove first hidden_states that likely contain outliers
h = hidden_states[:, self.n_sink :]
-query_states = get_query_states(module, h)
+query_states = get_prerope_query_states(module, h)

# Query mean
mu = query_states.mean(dim=2, keepdim=True)
2 changes: 1 addition & 1 deletion kvpress/presses/expected_attention_with_stats.py
@@ -269,7 +269,7 @@ def main(
text_column: Column name containing text in the dataset
device_map: Device mapping for the model
"""
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, torch_dtype=torch.bfloat16).eval()
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, dtype=torch.bfloat16).eval()

_, mu, cov = collect_queries(model, dataset_name, num_samples, sample_seq_len, n_sink, text_column)

4 changes: 2 additions & 2 deletions kvpress/presses/kvzip_press.py
@@ -14,7 +14,7 @@
from transformers.models.llama.modeling_llama import rotate_half

from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
-from kvpress.utils import extract_keys_and_values, get_query_states
+from kvpress.utils import extract_keys_and_values, get_prerope_query_states

logger = logging.getLogger(__name__)

@@ -300,7 +300,7 @@ def score_kvzip(
head_dim = module.head_dim
num_key_value_groups = num_heads // num_heads_kv

-queries = get_query_states(module, hidden_states)
+queries = get_prerope_query_states(module, hidden_states)

# Apply RoPE
cos, sin = kwargs["position_embeddings"]
4 changes: 2 additions & 2 deletions kvpress/presses/non_causal_attention_press.py
@@ -10,7 +10,7 @@
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half

from kvpress.presses.scorer_press import ScorerPress
-from kvpress.utils import get_query_states
+from kvpress.utils import get_prerope_query_states


@dataclass
@@ -105,7 +105,7 @@ def score(
assert keys.shape[-2] == n_queries, "NonCausalAttnPress only supports prefill"

cos, sin = kwargs["position_embeddings"]
-q = get_query_states(module, hidden_states) # (B, H_q, S, d)
+q = get_prerope_query_states(module, hidden_states) # (B, H_q, S, d)

q_len = q.shape[-2]
num_kv_groups = q.shape[1] // values.shape[1]
4 changes: 2 additions & 2 deletions kvpress/presses/snapkv_press.py
@@ -11,7 +11,7 @@
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half

from kvpress.presses.scorer_press import ScorerPress
-from kvpress.utils import get_query_states
+from kvpress.utils import get_prerope_query_states


@dataclass
@@ -50,7 +50,7 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_
num_key_value_groups = num_heads // module.config.num_key_value_heads

# Get last window_size queries
-query_states = get_query_states(module, hidden_states[:, -window_size:])
+query_states = get_prerope_query_states(module, hidden_states[:, -window_size:])

# Apply RoPE
cos, sin = position_embeddings
4 changes: 2 additions & 2 deletions kvpress/presses/think_press.py
@@ -9,7 +9,7 @@
from transformers.models.llama.modeling_llama import rotate_half

from kvpress.presses.base_press import BasePress
-from kvpress.utils import get_query_states
+from kvpress.utils import get_prerope_query_states


@dataclass
@@ -45,7 +45,7 @@ def compute_window_queries(self, module, hidden_states, position_embeddings):
Re-compute the last window_size query states
"""
# Get last self.window_size queries
-query_states = get_query_states(module, hidden_states[:, -self.window_size :])
+query_states = get_prerope_query_states(module, hidden_states[:, -self.window_size :])

# Apply RoPE
cos, sin = position_embeddings
2 changes: 1 addition & 1 deletion kvpress/utils.py
@@ -9,7 +9,7 @@
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention


-def get_query_states(module: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
+def get_prerope_query_states(module: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Extracts the query states from a given attention module and hidden states tensor.

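The new name signals that the returned queries have not yet had rotary position embeddings (RoPE) applied, which is why the call sites in the hunks above follow it with an "Apply RoPE" step. The sketch below shows that step on a single head vector in plain Python, using the rotate-half formulation. `rope_tables` and `apply_rope` are hypothetical helpers written for illustration, and the base of 10000 is an assumed default, not a value taken from this diff.

```python
import math


def rotate_half(x):
    """Swap the two halves of x and negate the (originally) second half."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rope_tables(position, head_dim, base=10000.0):
    """Return (cos, sin) lists of length head_dim for one token position."""
    half = head_dim // 2
    # One rotation angle per pair of dimensions (i, i + half).
    angles = [position / base ** (2 * i / head_dim) for i in range(half)]
    angles = angles + angles  # same angle for both members of a pair
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def apply_rope(q, cos, sin):
    """Rotate a pre-RoPE query vector: q * cos + rotate_half(q) * sin."""
    rotated = rotate_half(q)
    return [qi * c + ri * s for qi, ri, c, s in zip(q, rotated, cos, sin)]


pre_rope_query = [1.0, 2.0, 3.0, 4.0]
cos, sin = rope_tables(position=3, head_dim=4)
post_rope_query = apply_rope(pre_rope_query, cos, sin)
```

Because each pair of dimensions is rotated by an angle, the operation changes the direction of the query but not its norm, and position 0 leaves the vector unchanged.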
Binary file added leaderboard_plot_score.png
4 changes: 2 additions & 2 deletions notebooks/expected_attention.ipynb
@@ -83,7 +83,7 @@
"ckpt = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
"# ckpt = \"mistralai/Mistral-Nemo-Instruct-2407\"\n",
"# ckpt = \"microsoft/Phi-3.5-mini-instruct\"\n",
-"pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, torch_dtype=\"auto\", model_kwargs={\"attn_implementation\":\"flash_attention_2\"})\n",
+"pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, dtype=\"auto\", model_kwargs={\"attn_implementation\":\"flash_attention_2\"})\n",
"\n",
"# Load data\n",
"url = \"https://en.wikipedia.org/wiki/Nvidia\"\n",
@@ -199,7 +199,7 @@
"decoder_layer = pipe.model.model.layers[layer_idx]\n",
"self_attn = decoder_layer.self_attn\n",
"self_attn.rotary_emb = pipe.model.model.rotary_emb\n",
-"attention_mask = AttentionMaskConverter(is_causal=True).to_causal_4d(1, n_tokens, n_tokens, pipe.torch_dtype, pipe.device)\n",
+"attention_mask = AttentionMaskConverter(is_causal=True).to_causal_4d(1, n_tokens, n_tokens, pipe.dtype, pipe.device)\n",
"\n",
"with torch.no_grad():\n",
" # Compute expected attention (need lot of vRAM to keep all tokens)\n",
2 changes: 1 addition & 1 deletion notebooks/new_press.ipynb
@@ -45,7 +45,7 @@
"device = \"cuda:0\"\n",
"ckpt = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
"attn_implementation = \"flash_attention_2\"\n",
-"pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, torch_dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})"
+"pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})"
]
},
{
2 changes: 1 addition & 1 deletion notebooks/per_layer_compression_demo.ipynb
@@ -68,7 +68,7 @@
"device = \"cuda:0\"\n",
"ckpt = \"microsoft/Phi-3.5-mini-instruct\"\n",
"attn_implementation = \"flash_attention_2\"\n",
-"pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, torch_dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})"
+"pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})"
]
},
{
4 changes: 2 additions & 2 deletions notebooks/speed_and_memory.ipynb
@@ -110,7 +110,7 @@
" torch.cuda.reset_peak_memory_stats()\n",
" torch.cuda.empty_cache()\n",
" idle_peak_memory = torch.cuda.max_memory_allocated()\n",
-" model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=\"auto\", attn_implementation=\"flash_attention_2\").to(device)\n",
+" model = AutoModelForCausalLM.from_pretrained(ckpt, dtype=\"auto\", attn_implementation=\"flash_attention_2\").to(device)\n",
" initial_peak_memory = torch.cuda.max_memory_allocated()\n",
"\n",
" inputs =torch.arange(n_tokens).reshape([1, n_tokens]).to(device)\n",
@@ -160,7 +160,7 @@
" torch.cuda.reset_peak_memory_stats()\n",
" torch.cuda.empty_cache()\n",
" idle_peak_memory = torch.cuda.max_memory_allocated()\n",
-" model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=\"auto\", attn_implementation=\"flash_attention_2\").to(device)\n",
+" model = AutoModelForCausalLM.from_pretrained(ckpt, dtype=\"auto\", attn_implementation=\"flash_attention_2\").to(device)\n",
" # disable EosTokenCriteria stopping criteria\n",
" model.generation_config.eos_token_id = None\n",
" model.generation_config.stop_strings = None\n",