Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-blue)](https://huggingface.co/spaces/nvidia/kvpress)
[![Blog post](https://img.shields.io/badge/🤗%20Hugging%20Face-Blog-blue)](https://huggingface.co/blog/nvidia/kvpress)
[![Hugging Face Leaderboard](https://img.shields.io/badge/🤗%20HuggingFace-Leaderboard-orange)](https://huggingface.co/spaces/nvidia/kvpress-leaderboard)
[![Paper](https://img.shields.io/badge/📄%20arXiv-Paper-red)](https://arxiv.org/abs/2510.00636v1)

![kvpress](kvpress.jpg)

Expand Down
4 changes: 4 additions & 0 deletions evaluation/benchmarks/aime25/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# AIME 25

This dataset contains problems from the American Invitational Mathematics Examination (AIME) 2025-I & II.
See https://huggingface.co/datasets/opencompass/AIME2025
2 changes: 2 additions & 0 deletions evaluation/benchmarks/aime25/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
24 changes: 24 additions & 0 deletions evaluation/benchmarks/aime25/calculate_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import pandas as pd


def extract_boxed(pred_answer):
try:
return str(pred_answer.split("boxed{")[1].split("}")[0])
except IndexError:
return None


def score_aime(pred_answer, true_answer):
return extract_boxed(pred_answer) == str(true_answer)


def calculate_metrics(df: pd.DataFrame) -> dict:
correct = 0
answered = 0
for index, row in df.iterrows():
correct += score_aime(row["predicted_answer"], row["answer"])
answered += "boxed{" in row["predicted_answer"]
return {"correct": correct, "answered": answered, "accuracy": correct / len(df), "total": len(df)}
4 changes: 4 additions & 0 deletions evaluation/benchmarks/math500/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# MATH-500
Adapted from https://huggingface.co/datasets/HuggingFaceH4/MATH-500
This dataset contains a subset of 500 problems from the MATH benchmark that OpenAI created in their Let's Verify Step by Step paper. See their GitHub repo for the source file: https://github.com/openai/prm800k/tree/main?tab=readme-ov-file#math-splits

2 changes: 2 additions & 0 deletions evaluation/benchmarks/math500/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
24 changes: 24 additions & 0 deletions evaluation/benchmarks/math500/calculate_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import pandas as pd


def extract_boxed(pred_answer):
try:
return str(pred_answer.split("boxed{")[1].split("}")[0])
except IndexError:
return None


def score_aime(pred_answer, true_answer):
return extract_boxed(pred_answer) == str(true_answer)


def calculate_metrics(df: pd.DataFrame) -> dict:
correct = 0
answered = 0
for index, row in df.iterrows():
correct += score_aime(row["predicted_answer"], row["answer"])
answered += "boxed{" in row["predicted_answer"]
return {"correct": correct, "answered": answered, "accuracy": correct / len(df), "total": len(df)}
100 changes: 70 additions & 30 deletions evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
import numpy as np
import pandas as pd
import torch
import yaml # type: ignore[import-untyped]
import yaml
from benchmarks.needle_in_haystack.utils import insert_needle_in_haystack
from datasets import load_dataset
from evaluate_registry import DATASET_REGISTRY, PRESS_REGISTRY, SCORER_REGISTRY
from fire import Fire
from tqdm import tqdm
from transformers import Pipeline, pipeline

from kvpress import ComposedPress, DuoAttentionPress, FinchPress, ObservedAttentionPress, ThinKPress
from kvpress import ComposedPress, DuoAttentionPress, FinchPress, ObservedAttentionPress, ScorerPress, ThinKPress
from kvpress.presses.decoding_press import DecodingPress

logger = logging.getLogger(__name__)

Expand All @@ -45,6 +46,11 @@ class EvaluationConfig:
compress_questions: bool = False
needle_depth: Optional[int] = None

# Decoding parameters
compression_interval: Optional[int] = None
target_size: Optional[int] = None
hidden_states_buffer_size: Optional[int] = None

# Output and logging
output_dir: str = "./results"
log_level: str = "INFO"
Expand Down Expand Up @@ -187,7 +193,7 @@ def __init__(self, config: EvaluationConfig):
"""
self.config = config
self.pipeline: Optional[Pipeline] = None # Will be set by _setup_model_pipeline()
self.press = None # Will be set by _setup_press()
self.press: None | ScorerPress = None # Will be set by _setup_press()
self.df: Optional[pd.DataFrame] = None # Will be set by _load_dataset()
self._setup_logging()
self._setup_deterministic_seeds()
Expand Down Expand Up @@ -264,6 +270,13 @@ def _setup_press(self):
assert key_channel_compression_ratio is not None, "key_channel_compression_ratio must be set for ThinKPress"
press.key_channel_compression_ratio = key_channel_compression_ratio
logger.info(f"Set ThinKPress key_channel_compression_ratio to {key_channel_compression_ratio}")
elif isinstance(press, DecodingPress):
press.compression_interval = self.config.compression_interval or press.compression_interval
press.target_size = self.config.target_size or press.target_size
press.hidden_states_buffer_size = self.config.hidden_states_buffer_size or press.hidden_states_buffer_size
logger.info(
f"Set DecodingPress compression_interval to {self.config.compression_interval}, target_size to {self.config.target_size}, hidden_states_buffer_size to {self.config.hidden_states_buffer_size}"
)
else:
if hasattr(press, "compression_ratio"):
press.compression_ratio = compression_ratio
Expand Down Expand Up @@ -309,7 +322,9 @@ def _load_and_prepare_dataset(self):
# FinchPress uses a delimiter token to separate context and question
# So we need to update the tokenizer and the model embeddings.
logger.info("FinchPress detected, updating model and tokenizer with delimiter token.")
self.press.update_model_and_tokenizer(self.pipeline.model, self.pipeline.tokenizer) # type: ignore[attr-defined]
self.press.update_model_and_tokenizer(
self.pipeline.model, self.pipeline.tokenizer
) # type: ignore[attr-defined]
df["context"] = df["context"] + self.press.delimiter_token # type: ignore[attr-defined, index]

if self.config.compress_questions:
Expand Down Expand Up @@ -364,30 +379,54 @@ def _run_inference(self):
"""

self.df["predicted_answer"] = None # type: ignore[index]
df_context_grouped = self.df.groupby("context") # type: ignore[union-attr]
assert all(
df_context_grouped["answer_prefix"].nunique() == 1
), "Inconsistent 'answer_prefix' within the same context group detected."

logger.info("Starting inference...")
for context, df_group in tqdm(df_context_grouped, total=self.df["context"].nunique(), desc="Running Inference"): # type: ignore[union-attr]
questions = df_group["question"].to_list()
# Use max_new_tokens from config, or fallback to dataset's default for the task
max_new_tokens = self.config.max_new_tokens or df_group["max_new_tokens"].iloc[0]
answer_prefix = df_group["answer_prefix"].iloc[0]

output = self.pipeline( # type: ignore[misc]
context,
questions=questions,
answer_prefix=answer_prefix,
press=self.press,
max_new_tokens=max_new_tokens,
max_context_length=self.config.max_context_length,
)
self.df.loc[df_group.index, "predicted_answer"] = output["answers"] # type: ignore[union-attr]
# Store the actual compression ratio used (if the press has one)
self.df.loc[df_group.index, "compression_ratio"] = self.press.compression_ratio if self.press is not None else 0.0 # type: ignore[union-attr, attr-defined]
torch.cuda.empty_cache() # Clear CUDA cache to free up memory

if isinstance(self.press, DecodingPress):
logger.info("DecodingPress detected, running inference for each context-question pair.")
for index, row in tqdm(self.df.iterrows(), total=len(self.df), desc="Running Inference"):
context = row["context"]
question = row["question"]
answer_prefix = row["answer_prefix"]
max_new_tokens = self.config.max_new_tokens or row["max_new_tokens"]
output = self.pipeline(
context,
question=question,
answer_prefix=answer_prefix,
press=self.press,
max_new_tokens=max_new_tokens,
max_context_length=self.config.max_context_length,
)
self.df.loc[index, "predicted_answer"] = output["answer"] # type: ignore[union-attr]
torch.cuda.empty_cache() # Clear CUDA cache to free up memory

else:
df_context_grouped = self.df.groupby("context") # type: ignore[union-attr]
assert all(
df_context_grouped["answer_prefix"].nunique() == 1
), "Inconsistent 'answer_prefix' within the same context group detected."

logger.info("Starting inference...")
for context, df_group in tqdm(
df_context_grouped, total=self.df["context"].nunique(), desc="Running Inference"
): # type: ignore[union-attr]
questions = df_group["question"].to_list()
# Use max_new_tokens from config, or fallback to dataset's default for the task
max_new_tokens = self.config.max_new_tokens or df_group["max_new_tokens"].iloc[0]
answer_prefix = df_group["answer_prefix"].iloc[0]

output = self.pipeline( # type: ignore[misc]
context,
questions=questions,
answer_prefix=answer_prefix,
press=self.press,
max_new_tokens=max_new_tokens,
max_context_length=self.config.max_context_length,
)
self.df.loc[df_group.index, "predicted_answer"] = output["answers"] # type: ignore[union-attr]
# Store the actual compression ratio used (if the press has one)
self.df.loc[df_group.index, "compression_ratio"] = (
self.press.compression_ratio if self.press is not None else 0.0 # type: ignore[attr-defined]
) # type: ignore[union-attr, attr-defined]
torch.cuda.empty_cache() # Clear CUDA cache to free up memory

logger.info("Inference completed.")

Expand All @@ -403,7 +442,9 @@ def _save_results(self, save_filename: Path):
if save_filename.exists():
logger.warning(f"Results CSV already exists at {save_filename}. Overwriting.")

self.df[list(set(self.df.columns) - set(["context"]))].to_csv(str(save_filename), index=False) # type: ignore[index]
self.df[list(set(self.df.columns) - set(["context"]))].to_csv(
str(save_filename), index=False
) # type: ignore[index]
logger.info(f"Results saved to {save_filename}")

def _calculate_and_save_metrics(self, save_filename: Path):
Expand All @@ -425,7 +466,6 @@ def _calculate_and_save_metrics(self, save_filename: Path):
json.dump(metrics, f, indent=4) # Pretty print JSON

logger.info(f"Metrics saved to {save_filename}")
logger.info(f"Average compression ratio: {self.df['compression_ratio'].mean():.2f}") # type: ignore[index]
logger.info(f"Metrics:\n{json.dumps(metrics, indent=2)}")

def run_evaluation(self):
Expand Down
15 changes: 15 additions & 0 deletions evaluation/evaluate_registry.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from benchmarks.aime25.calculate_metrics import calculate_metrics as aime25_scorer
from benchmarks.infinite_bench.calculate_metrics import calculate_metrics as infinite_bench_scorer
from benchmarks.longbench.calculate_metrics import calculate_metrics as longbench_scorer
from benchmarks.longbench.calculate_metrics import calculate_metrics_e as longbench_scorer_e
from benchmarks.longbenchv2.calculate_metrics import calculate_metrics as longbenchv2_scorer
from benchmarks.loogle.calculate_metrics import calculate_metrics as loogle_scorer
from benchmarks.math500.calculate_metrics import calculate_metrics as math500_scorer
from benchmarks.needle_in_haystack.calculate_metrics import calculate_metrics as needle_in_haystack_scorer
from benchmarks.ruler.calculate_metrics import calculate_metrics as ruler_scorer
from benchmarks.zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer
Expand All @@ -32,6 +34,7 @@
ThinKPress,
TOVAPress,
)
from kvpress.presses.decoding_press import DecodingPress

# These dictionaries define the available datasets, scorers, and KVPress methods for evaluation.
DATASET_REGISTRY = {
Expand All @@ -43,6 +46,9 @@
"longbench-e": "Xnhyacinth/LongBench",
"longbench-v2": "Xnhyacinth/LongBench-v2",
"needle_in_haystack": "alessiodevoto/paul_graham_essays",
# Datasets used to be used for decoding compression
"aime25": "alessiodevoto/aime25",
"math500": "alessiodevoto/math500",
}

SCORER_REGISTRY = {
Expand All @@ -54,6 +60,8 @@
"longbench-e": longbench_scorer_e,
"longbench-v2": longbenchv2_scorer,
"needle_in_haystack": needle_in_haystack_scorer,
"aime25": aime25_scorer,
"math500": math500_scorer,
}


Expand Down Expand Up @@ -84,4 +92,11 @@
"think": ThinKPress(),
"tova": TOVAPress(),
"no_press": None,
"decoding_knorm": DecodingPress(base_press=KnormPress()),
"decoding_streaming_llm": DecodingPress(base_press=StreamingLLMPress()),
"decoding_tova": DecodingPress(base_press=TOVAPress()),
"decoding_qfilter": DecodingPress(base_press=QFilterPress()),
"decoding_adakv_expected_attention_e2": DecodingPress(base_press=AdaKVPress(ExpectedAttentionPress(epsilon=1e-2))),
"decoding_adakv_snapkv": DecodingPress(base_press=AdaKVPress(SnapKVPress())),
"decoding_keydiff": DecodingPress(base_press=KeyDiffPress()),
}
5 changes: 3 additions & 2 deletions kvpress/presses/decoding_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.nn as nn
from transformers.cache_utils import QuantizedCache

from kvpress import AdaKVPress
from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.utils import extract_keys_and_values
Expand Down Expand Up @@ -39,14 +40,14 @@ class DecodingPress(BasePress):
current hidden state for compression scoring.
"""

base_press: ScorerPress
base_press: ScorerPress | AdaKVPress
compression_interval: int = 128
target_size: int = 1024
hidden_states_buffer_size: int = 128

def __post_init__(self):
# Buffer to store hidden states during decoding (per layer)
assert isinstance(self.base_press, ScorerPress), "DecodingPress requires a ScorerPress as input"
assert isinstance(self.base_press, (ScorerPress, AdaKVPress)), "DecodingPress requires a ScorerPress as input"
self.hidden_states_buffer = defaultdict(list) # Per-layer buffer
self.layer_step_counts = defaultdict(int) # Track step count per layer

Expand Down
Loading