From e9a5339d9f9b51d3adfa5d900cba04e689834ad4 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Wed, 27 Aug 2025 12:32:30 +0000 Subject: [PATCH 01/10] ea with stats Signed-off-by: alessiodevoto --- kvpress/__init__.py | 2 + kvpress/presses/expected_attention_press.py | 57 ++-- .../presses/expected_attention_with_stats.py | 275 ++++++++++++++++++ 3 files changed, 317 insertions(+), 17 deletions(-) create mode 100644 kvpress/presses/expected_attention_with_stats.py diff --git a/kvpress/__init__.py b/kvpress/__init__.py index bae791d8..e9b07e15 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -30,6 +30,7 @@ from kvpress.presses.streaming_llm_press import StreamingLLMPress from kvpress.presses.think_press import ThinKPress from kvpress.presses.tova_press import TOVAPress +from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress # Patch the attention functions to support head-wise compression patch_attention_functions() @@ -63,4 +64,5 @@ "BlockPress", "KeyDiffPress", "KVzipPress", + "ExpectedAttentionStatsPress", ] diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index e22e2b02..618f12e8 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -65,7 +65,6 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): """ q_len = hidden_states.shape[1] - head_dim = module.head_dim # Remove first hidden_states that likely contain outliers h = hidden_states[:, self.n_sink :] @@ -81,36 +80,60 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): cov = torch.einsum("bnsi,bnsj->bnij", centered_states, centered_states) / h.shape[1] mu = mu.squeeze(2) - # RoPE rotation matrix on next n_future_positions + # Apply RoPE to the mean and covariance matrix of the queries + mu, cov = self.apply_avg_rope(module, mu, cov, q_len) + + # Instead of using the average rotation matrix, we could use a mixture of gaussian statistics to + # estimate mean and covariance. Estimation is better, but end-to-end performance was lower. + # mu = torch.einsum("bhj, fij -> bhfi", mu, R) + # mean_mu = mu.mean(dim=2, keepdim=True) + # if self.use_covariance: + # cov = torch.einsum("fki, bhkl, fjl -> bhfij", R, cov, R) + # cov = cov.mean(dim=2) + # cov += torch.einsum("bhfi, bhfj -> bhji", mu - mean_mu, mu - mean_mu) / self.n_future_positions + # mu = mean_mu.squeeze(2) + + return mu, cov + + def apply_avg_rope(self, module: nn.Module, mu: torch.Tensor, cov: torch.Tensor, q_len: int): + """ + Apply average RoPE to the mean and covariance matrix of the queries + + Parameters + ---------- + module : nn.Module + The module to apply RoPE to. + mu : torch.Tensor + The mean of the queries. + cov : torch.Tensor + The covariance matrix of the queries. + q_len : int + The length of the queries. + + Returns + ------- + mu : torch.Tensor + The mean of the queries after RoPE. + cov : torch.Tensor + The covariance matrix of the queries after RoPE. + """ position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device) + head_dim = module.head_dim cos, sin = module.rotary_emb(mu, position_ids) cos, sin = cos[0], sin[0] - Id = torch.eye(head_dim, device=cos.device, dtype=cos.dtype) P = torch.zeros((head_dim, head_dim), device=cos.device, dtype=cos.dtype) P[head_dim // 2 :, : head_dim // 2], P[: head_dim // 2, head_dim // 2 :] = torch.eye(head_dim // 2), -torch.eye( head_dim // 2 ) R = cos.unsqueeze(1) * Id + sin.unsqueeze(1) * P - - # Apply average rotation to the mean and covariance R = R.mean(dim=0).to(mu.device) mu = torch.matmul(mu, R.T) - if self.use_covariance: + if cov is not None: cov = torch.matmul(R, torch.matmul(cov, R.T)) - - # Instead of using the average rotation matrix, we could use a mixture of gaussian statistics to - # estimate mean and covariance. Estimation is better, but end-to-end performance was lower. - # mu = torch.einsum("bhj, fij -> bhfi", mu, R) - # mean_mu = mu.mean(dim=2, keepdim=True) - # if self.use_covariance: - # cov = torch.einsum("fki, bhkl, fjl -> bhfij", R, cov, R) - # cov = cov.mean(dim=2) - # cov += torch.einsum("bhfi, bhfj -> bhji", mu - mean_mu, mu - mean_mu) / self.n_future_positions - # mu = mean_mu.squeeze(2) - return mu, cov + def score( self, module: nn.Module, diff --git a/kvpress/presses/expected_attention_with_stats.py b/kvpress/presses/expected_attention_with_stats.py new file mode 100644 index 00000000..c198d5a9 --- /dev/null +++ b/kvpress/presses/expected_attention_with_stats.py @@ -0,0 +1,275 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +from dataclasses import dataclass, field +from typing import Optional +import importlib +from contextlib import contextmanager +import torch +from torch import nn + +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedModel +from huggingface_hub import PyTorchModelHubMixin +from transformers import AutoModelForCausalLM + +from kvpress.presses.expected_attention_press import ExpectedAttentionPress + + +@dataclass +class ExpectedAttentionStatsPress(ExpectedAttentionPress): + """ + Expected attention press that automatically loads pre-computed query statistics. + + + Parameters + ---------- + compression_ratio : float, default=0.0 + Fraction of key-value pairs to remove during compression. + n_future_positions : int, default=512 + Number of future positions to consider when computing expected attention. + n_sink : int, default=4 + Number of initial tokens to exclude from compression (sink tokens). + use_covariance : bool, default=True + Whether to include covariance information in expected attention computation. + use_vnorm : bool, default=True + Whether to rescale scores using value vector norms. + epsilon : float, default=0.0 + Small constant added to scores before value norm rescaling. + stats_dataset : str, default="kmfoda/booksum" + Dataset used to compute the statistics. + n_samples : int, default=100 + Number of samples used to compute the statistics. + sample_seq_len : int, default=1000 + Sequence length used to compute the statistics. + """ + + # Override parent defaults to enable stats by default + sample_seq_len: int = 1000 + n_samples: int = 100 + stats_dataset: str = "kmfoda/booksum" + stats_folder: Optional[str] = None + + + mu: torch.Tensor = field(init=False, default=None) # initialized in __post_init_from_model__ + cov: torch.Tensor = field(init=False, default=None) # initialized in __post_init_from_model__ + + + def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): + """ + Compute the mean and covariance matrix of the queries and apply average RoPE to them. + """ + print(f"Applying average RoPE to the mean and covariance matrix of the queries") + mu, cov = self.apply_avg_rope(module, self.mu, self.cov, hidden_states.shape[1]) + return mu, cov + + + def _maybe_load_stats_from_hub(self, model: PreTrainedModel): + """Load statistics from the Hugging Face Hub.""" + stats_id = ExpectedAttentionStats( + model_name=model.config.name_or_path, + num_layers=model.config.num_hidden_layers, + num_heads=model.config.num_attention_heads, + head_dim=model.config.head_dim, + dataset_name=self.stats_dataset, + num_samples=self.n_samples, + sample_seq_len=self.sample_seq_len, + n_sink=self.n_sink, + ).stats_id() + try: + return ExpectedAttentionStats.from_pretrained(stats_id) + except ValueError: + raise ValueError(f"No statistics found for model {stats_id} on the Hub. Please compute them first.") + + + def __post_init_from_model__(self, model): + """ + Automatically load or compute query statistics for the model. + """ + if self.stats_folder is not None: + stats = ExpectedAttentionStats.from_pretrained(self.stats_folder) + else: + print(f"Loading statistics from the Hub") + stats = self._maybe_load_stats_from_hub(model) + print(f"Loaded statistics from the Hub") + self.mu = stats.query_mean.data.to(model.device) + self.cov = stats.query_cov.data.to(model.device) + + @contextmanager + def __call__(self, model): + self.__post_init_from_model__(model) + with super().__call__(model): + yield + + + +class ExpectedAttentionStats(torch.nn.Module, PyTorchModelHubMixin): + """ + Module that stores the mean and covariance matrix of the queries, possibly uploaded to the HF hub. + + """ + + def __init__( + self, + num_layers: int, + num_heads: int, + head_dim: int, + dataset_name: str, + model_name: str, + num_samples: int, + sample_seq_len: int, + n_sink: int, + ): + super().__init__() + self.query_mean = torch.nn.Parameter(torch.zeros(num_layers, num_heads, head_dim)) + self.query_cov = torch.nn.Parameter(torch.zeros(num_layers, num_heads, head_dim, head_dim)) + self.dataset_name = dataset_name + self.model_name = model_name + self.num_samples = num_samples + self.sample_seq_len = sample_seq_len + self.n_sink = n_sink + + def stats_id(self) -> str: + """Generate the statistics ID for the model and configuration.""" + return f"alessiodevoto/exp_att_stats_{self.model_name.replace('/', '_')}_{self.dataset_name.replace('/', '_')}_{self.num_samples}_{self.sample_seq_len}_{self.n_sink}" + + +# The code below is used to collect statistics on a dataset, and is not used in the press. + +@contextmanager +def patch_rotary_embedding(model): + """ + A context manager to dynamically patch the `apply_rotary_pos_emb` function + for any supported model architecture. It captures the query states before + rotary embeddings are applied. + + Args: + model (PreTrainedModel): The transformer model instance. + + Yields: + list: A list that will be populated with the captured query tensors. + """ + # Dynamically find the model's specific "modeling" module + try: + module_path = model.__class__.__module__ + modeling_module = importlib.import_module(module_path) + except Exception as e: + raise RuntimeError(f"Failed to import module for {model.__class__.__name__}: {e}") + + # Check for the target function and save the original + target_function = "apply_rotary_pos_emb" + if not hasattr(modeling_module, target_function): + raise AttributeError( + f"Model architecture '{model.config.model_type}' is not supported. " + f"The module '{module_path}' does not contain '{target_function}'." + ) + + original_function = getattr(modeling_module, target_function) + + captured_tensors = [] + + def patched_function(q_embed, k_embed, *args, **kwargs): + # Capture the query tensor before RoPE is applied + captured_tensors.append(q_embed.detach()) + q_embed, k_embed = original_function(q_embed, k_embed, *args, **kwargs) + return q_embed, k_embed + + # Apply the patch + setattr(modeling_module, target_function, patched_function) + + try: + yield captured_tensors + finally: + setattr(modeling_module, target_function, original_function) + + +@torch.inference_mode() +def collect_queries( + model: PreTrainedModel, + dataset_name: str, + num_samples: int, + q_len: int, + n_sink: int, + return_stats: bool = False, + text_column: str = "chapter", +) -> list[torch.Tensor]: + """ + Collects query representations from a transformer model using a calibration dataset. + + This function runs the model on a small number of samples from the "kmfoda/booksum" dataset, + capturing the query tensors after rotary positional embeddings are applied. It trims the + input text to a maximum length (`q_len`), skips the first `n_sink` tokens (to avoid outliers), + and returns the collected queries. + + Args: + model (PreTrainedModel): The transformer model instance. + dataset_name (str): Name of the dataset to use for collecting statistics. + num_samples (int): Number of samples to use from the calibration dataset. + q_len (int): Maximum sequence length to consider for each sample. + n_sink (int): Number of initial tokens to exclude from the collected queries. + return_stats (bool): Whether to return the mean and covariance of the queries. + text_column (str): Name of the column in the dataset containing the text to tokenize. + + Returns: + list or tuple: + collected_queries (list): List of query tensors, each of shape (num_layers, num_heads, seq_len, head_dim) + mean_query (torch.Tensor): Mean query vector for each layer and head. + cov_query (torch.Tensor): Covariance matrix of queries for each layer and head. + If return_stats is False, only the list of query tensors is returned. + """ + + # Load dataset and tokenizer + tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) + dataset = load_dataset(dataset_name, split=f"train[:{num_samples}]") + + # Cut to max q_len + dataset = dataset.map(lambda x: {text_column: x[text_column][:q_len]}) + + collected_queries = [] + for text in tqdm(dataset[text_column], desc="Collecting queries"): + inputs = tokenizer(text, return_tensors="pt").to(model.device) + with patch_rotary_embedding(model) as captured_queries: + model(**inputs) + collected_queries.append(torch.cat(captured_queries, dim=0)[:, :, n_sink:, :]) + + if return_stats: + cat_queries = torch.cat(collected_queries, dim=-2) + mean_query = cat_queries.mean(dim=-2) + # compute covariance manually + centered_queries = cat_queries - mean_query.unsqueeze(-2) + N = cat_queries.shape[-2] + cov_query = (centered_queries.transpose(-2, -1) @ centered_queries) / (N - 1) + return collected_queries, mean_query, cov_query + else: + return collected_queries + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.2-1B-Instruct") + parser.add_argument("--output_path", type=str, default=".") + parser.add_argument("--dataset_name", type=str, default="kmfoda/booksum") + parser.add_argument("--num_samples", type=int, default=100) + parser.add_argument("--sample_seq_len", type=int, default=1000) + parser.add_argument("--n_sink", type=int, default=4) + parser.add_argument("--text_column", type=str, default="chapter") + parser.add_argument("--device_map", type=str, default="auto") + args = parser.parse_args() + model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map=args.device_map, torch_dtype=torch.bfloat16).eval() + _, mu, cov = collect_queries(model, args.dataset_name, args.num_samples, args.sample_seq_len, args.n_sink, return_stats=True) + + stats = ExpectedAttentionStats( + num_layers=model.config.num_hidden_layers, + num_heads=model.config.num_attention_heads, + head_dim=model.config.head_dim, + dataset_name=args.dataset_name, + model_name=args.model_name, + num_samples=args.num_samples, + sample_seq_len=args.sample_seq_len, + n_sink=args.n_sink, + ) + output_path = os.path.join(args.output_path, stats.stats_id()) + stats.save_pretrained(output_path) \ No newline at end of file From b53e9e2cdf74c1fe8bc7a50cc9719cb5275809ba Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Wed, 27 Aug 2025 12:40:49 +0000 Subject: [PATCH 02/10] style Signed-off-by: alessiodevoto --- kvpress/__init__.py | 2 +- kvpress/presses/expected_attention_press.py | 1 - .../presses/expected_attention_with_stats.py | 69 +++++++++---------- 3 files changed, 34 insertions(+), 38 deletions(-) diff --git a/kvpress/__init__.py b/kvpress/__init__.py index e9b07e15..cccd1f02 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -13,6 +13,7 @@ from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress from kvpress.presses.duo_attention_press import DuoAttentionPress from kvpress.presses.expected_attention_press import ExpectedAttentionPress +from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress from kvpress.presses.finch_press import FinchPress from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.keydiff_press import KeyDiffPress @@ -30,7 +31,6 @@ from kvpress.presses.streaming_llm_press import StreamingLLMPress from kvpress.presses.think_press import ThinKPress from kvpress.presses.tova_press import TOVAPress -from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress # Patch the attention functions to support head-wise compression patch_attention_functions() diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index 618f12e8..4162a255 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -133,7 +133,6 @@ def apply_avg_rope(self, module: nn.Module, mu: torch.Tensor, cov: torch.Tensor, cov = torch.matmul(R, torch.matmul(cov, R.T)) return mu, cov - def score( self, module: nn.Module, diff --git a/kvpress/presses/expected_attention_with_stats.py b/kvpress/presses/expected_attention_with_stats.py index c198d5a9..3838ce48 100644 --- a/kvpress/presses/expected_attention_with_stats.py +++ b/kvpress/presses/expected_attention_with_stats.py @@ -1,19 +1,18 @@ # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import importlib import os +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Optional -import importlib -from contextlib import contextmanager -import torch -from torch import nn +import torch from datasets import load_dataset -from tqdm import tqdm -from transformers import AutoTokenizer, PreTrainedModel from huggingface_hub import PyTorchModelHubMixin -from transformers import AutoModelForCausalLM +from torch import nn +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel from kvpress.presses.expected_attention_press import ExpectedAttentionPress @@ -22,8 +21,8 @@ class ExpectedAttentionStatsPress(ExpectedAttentionPress): """ Expected attention press that automatically loads pre-computed query statistics. - - + + Parameters ---------- compression_ratio : float, default=0.0 @@ -52,20 +51,16 @@ class ExpectedAttentionStatsPress(ExpectedAttentionPress): stats_dataset: str = "kmfoda/booksum" stats_folder: Optional[str] = None - mu: torch.Tensor = field(init=False, default=None) # initialized in __post_init_from_model__ cov: torch.Tensor = field(init=False, default=None) # initialized in __post_init_from_model__ - - + def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): """ Compute the mean and covariance matrix of the queries and apply average RoPE to them. """ - print(f"Applying average RoPE to the mean and covariance matrix of the queries") mu, cov = self.apply_avg_rope(module, self.mu, self.cov, hidden_states.shape[1]) return mu, cov - def _maybe_load_stats_from_hub(self, model: PreTrainedModel): """Load statistics from the Hugging Face Hub.""" stats_id = ExpectedAttentionStats( @@ -82,7 +77,6 @@ def _maybe_load_stats_from_hub(self, model: PreTrainedModel): return ExpectedAttentionStats.from_pretrained(stats_id) except ValueError: raise ValueError(f"No statistics found for model {stats_id} on the Hub. Please compute them first.") - def __post_init_from_model__(self, model): """ @@ -91,9 +85,7 @@ def __post_init_from_model__(self, model): if self.stats_folder is not None: stats = ExpectedAttentionStats.from_pretrained(self.stats_folder) else: - print(f"Loading statistics from the Hub") stats = self._maybe_load_stats_from_hub(model) - print(f"Loaded statistics from the Hub") self.mu = stats.query_mean.data.to(model.device) self.cov = stats.query_cov.data.to(model.device) @@ -103,7 +95,6 @@ def __call__(self, model): with super().__call__(model): yield - class ExpectedAttentionStats(torch.nn.Module, PyTorchModelHubMixin): """ @@ -112,15 +103,15 @@ class ExpectedAttentionStats(torch.nn.Module, PyTorchModelHubMixin): """ def __init__( - self, - num_layers: int, - num_heads: int, - head_dim: int, - dataset_name: str, - model_name: str, - num_samples: int, - sample_seq_len: int, - n_sink: int, + self, + num_layers: int, + num_heads: int, + head_dim: int, + dataset_name: str, + model_name: str, + num_samples: int, + sample_seq_len: int, + n_sink: int, ): super().__init__() self.query_mean = torch.nn.Parameter(torch.zeros(num_layers, num_heads, head_dim)) @@ -130,14 +121,15 @@ def __init__( self.num_samples = num_samples self.sample_seq_len = sample_seq_len self.n_sink = n_sink - + def stats_id(self) -> str: """Generate the statistics ID for the model and configuration.""" - return f"alessiodevoto/exp_att_stats_{self.model_name.replace('/', '_')}_{self.dataset_name.replace('/', '_')}_{self.num_samples}_{self.sample_seq_len}_{self.n_sink}" + return f"alessiodevoto/exp_att_stats_{self.model_name.replace('/', '_')}_{self.dataset_name.replace('/', '_')}_{self.num_samples}_{self.sample_seq_len}_{self.n_sink}" # noqa: E501 # The code below is used to collect statistics on a dataset, and is not used in the press. + @contextmanager def patch_rotary_embedding(model): """ @@ -187,14 +179,14 @@ def patched_function(q_embed, k_embed, *args, **kwargs): @torch.inference_mode() def collect_queries( - model: PreTrainedModel, + model: PreTrainedModel, dataset_name: str, - num_samples: int, - q_len: int, + num_samples: int, + q_len: int, n_sink: int, return_stats: bool = False, text_column: str = "chapter", -) -> list[torch.Tensor]: +) -> list[torch.Tensor] | tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]: """ Collects query representations from a transformer model using a calibration dataset. @@ -248,6 +240,7 @@ def collect_queries( if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser() parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.2-1B-Instruct") parser.add_argument("--output_path", type=str, default=".") @@ -258,8 +251,12 @@ def collect_queries( parser.add_argument("--text_column", type=str, default="chapter") parser.add_argument("--device_map", type=str, default="auto") args = parser.parse_args() - model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map=args.device_map, torch_dtype=torch.bfloat16).eval() - _, mu, cov = collect_queries(model, args.dataset_name, args.num_samples, args.sample_seq_len, args.n_sink, return_stats=True) + model = AutoModelForCausalLM.from_pretrained( + args.model_name, device_map=args.device_map, torch_dtype=torch.bfloat16 + ).eval() + _, mu, cov = collect_queries( + model, args.dataset_name, args.num_samples, args.sample_seq_len, args.n_sink, return_stats=True + ) stats = ExpectedAttentionStats( num_layers=model.config.num_hidden_layers, @@ -272,4 +269,4 @@ def collect_queries( n_sink=args.n_sink, ) output_path = os.path.join(args.output_path, stats.stats_id()) - stats.save_pretrained(output_path) \ No newline at end of file + stats.save_pretrained(output_path) From a8ec695e0981e7c87df2086eeac440e8b1537168 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Wed, 27 Aug 2025 13:09:27 +0000 Subject: [PATCH 03/10] style Signed-off-by: alessiodevoto --- .../presses/expected_attention_with_stats.py | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/kvpress/presses/expected_attention_with_stats.py b/kvpress/presses/expected_attention_with_stats.py index 3838ce48..93d35f17 100644 --- a/kvpress/presses/expected_attention_with_stats.py +++ b/kvpress/presses/expected_attention_with_stats.py @@ -56,11 +56,22 @@ class ExpectedAttentionStatsPress(ExpectedAttentionPress): def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): """ - Compute the mean and covariance matrix of the queries and apply average RoPE to them. + Override the parent method to use the pre-computed query statistics. """ mu, cov = self.apply_avg_rope(module, self.mu, self.cov, hidden_states.shape[1]) return mu, cov + def __post_init_from_model__(self, model): + """ + Automatically load or compute query statistics for the model. + """ + if self.stats_folder is not None: + stats = ExpectedAttentionStats.from_pretrained(self.stats_folder) + else: + stats = self._maybe_load_stats_from_hub(model) + self.mu = stats.query_mean.data.to(model.device) + self.cov = stats.query_cov.data.to(model.device) + def _maybe_load_stats_from_hub(self, model: PreTrainedModel): """Load statistics from the Hugging Face Hub.""" stats_id = ExpectedAttentionStats( @@ -76,18 +87,13 @@ def _maybe_load_stats_from_hub(self, model: PreTrainedModel): try: return ExpectedAttentionStats.from_pretrained(stats_id) except ValueError: - raise ValueError(f"No statistics found for model {stats_id} on the Hub. Please compute them first.") - - def __post_init_from_model__(self, model): - """ - Automatically load or compute query statistics for the model. - """ - if self.stats_folder is not None: - stats = ExpectedAttentionStats.from_pretrained(self.stats_folder) - else: - stats = self._maybe_load_stats_from_hub(model) - self.mu = stats.query_mean.data.to(model.device) - self.cov = stats.query_cov.data.to(model.device) + raise ValueError( + f"No statistics found for model {stats_id} on the Hub. Please compute them first. " + "You can do so by running the following code: " + "```" + "python expected_attention_with_stats.py --model_name " + "```" + ) @contextmanager def __call__(self, model): @@ -99,7 +105,6 @@ def __call__(self, model): class ExpectedAttentionStats(torch.nn.Module, PyTorchModelHubMixin): """ Module that stores the mean and covariance matrix of the queries, possibly uploaded to the HF hub. - """ def __init__( @@ -127,7 +132,7 @@ def stats_id(self) -> str: return f"alessiodevoto/exp_att_stats_{self.model_name.replace('/', '_')}_{self.dataset_name.replace('/', '_')}_{self.num_samples}_{self.sample_seq_len}_{self.n_sink}" # noqa: E501 -# The code below is used to collect statistics on a dataset, and is not used in the press. +# The code below is used to collect statistics on a dataset. @contextmanager From 2013f0ba919a5bf8adc4968da71af4613cbf499b Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Wed, 27 Aug 2025 14:39:50 +0000 Subject: [PATCH 04/10] update fire and query logic Signed-off-by: alessiodevoto --- .../presses/expected_attention_with_stats.py | 103 ++++++++++-------- 1 file changed, 57 insertions(+), 46 deletions(-) diff --git a/kvpress/presses/expected_attention_with_stats.py b/kvpress/presses/expected_attention_with_stats.py index 93d35f17..768a6676 100644 --- a/kvpress/presses/expected_attention_with_stats.py +++ b/kvpress/presses/expected_attention_with_stats.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field from typing import Optional +import fire import torch from datasets import load_dataset from huggingface_hub import PyTorchModelHubMixin @@ -59,18 +60,19 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): Override the parent method to use the pre-computed query statistics. """ mu, cov = self.apply_avg_rope(module, self.mu, self.cov, hidden_states.shape[1]) - return mu, cov + return mu[module.layer_idx].unsqueeze(0), cov[module.layer_idx].unsqueeze(0) def __post_init_from_model__(self, model): """ Automatically load or compute query statistics for the model. """ - if self.stats_folder is not None: - stats = ExpectedAttentionStats.from_pretrained(self.stats_folder) - else: - stats = self._maybe_load_stats_from_hub(model) - self.mu = stats.query_mean.data.to(model.device) - self.cov = stats.query_cov.data.to(model.device) + if self.mu is None and self.cov is None: + if self.stats_folder is not None: + stats = ExpectedAttentionStats.from_pretrained(self.stats_folder) + else: + stats = self._maybe_load_stats_from_hub(model) + self.mu = stats.query_mean.data.to(model.device, dtype=model.dtype) + self.cov = stats.query_cov.data.to(model.device, dtype=model.dtype) def _maybe_load_stats_from_hub(self, model: PreTrainedModel): """Load statistics from the Hugging Face Hub.""" @@ -189,9 +191,8 @@ def collect_queries( num_samples: int, q_len: int, n_sink: int, - return_stats: bool = False, text_column: str = "chapter", -) -> list[torch.Tensor] | tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]: +) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]: """ Collects query representations from a transformer model using a calibration dataset. @@ -206,7 +207,6 @@ def collect_queries( num_samples (int): Number of samples to use from the calibration dataset. q_len (int): Maximum sequence length to consider for each sample. n_sink (int): Number of initial tokens to exclude from the collected queries. - return_stats (bool): Whether to return the mean and covariance of the queries. text_column (str): Name of the column in the dataset containing the text to tokenize. Returns: @@ -214,7 +214,6 @@ def collect_queries( collected_queries (list): List of query tensors, each of shape (num_layers, num_heads, seq_len, head_dim) mean_query (torch.Tensor): Mean query vector for each layer and head. cov_query (torch.Tensor): Covariance matrix of queries for each layer and head. - If return_stats is False, only the list of query tensors is returned. """ # Load dataset and tokenizer @@ -231,47 +230,59 @@ def collect_queries( model(**inputs) collected_queries.append(torch.cat(captured_queries, dim=0)[:, :, n_sink:, :]) - if return_stats: - cat_queries = torch.cat(collected_queries, dim=-2) - mean_query = cat_queries.mean(dim=-2) - # compute covariance manually - centered_queries = cat_queries - mean_query.unsqueeze(-2) - N = cat_queries.shape[-2] - cov_query = (centered_queries.transpose(-2, -1) @ centered_queries) / (N - 1) - return collected_queries, mean_query, cov_query - else: - return collected_queries + cat_queries = torch.cat(collected_queries, dim=-2) + mean_query = cat_queries.mean(dim=-2) + # compute covariance manually + centered_queries = cat_queries - mean_query.unsqueeze(-2) + N = cat_queries.shape[-2] + cov_query = (centered_queries.transpose(-2, -1) @ centered_queries) / (N - 1) + return collected_queries, mean_query, cov_query + + +def main( + model_name: str = "meta-llama/Llama-3.1-8B-Instruct", + output_path: str = ".", + dataset_name: str = "kmfoda/booksum", + num_samples: int = 100, + sample_seq_len: int = 1000, + n_sink: int = 4, + text_column: str = "chapter", + device_map: str = "auto", +): + """ + Collect query statistics for a transformer model and save them. + Args: + model_name: Name of the model to collect statistics for + output_path: Directory to save the statistics + dataset_name: Dataset to use for collecting statistics + num_samples: Number of samples to use from the dataset + sample_seq_len: Sequence length for each sample + n_sink: Number of initial tokens to exclude + text_column: Column name containing text in the dataset + device_map: Device mapping for the model + """ + model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, torch_dtype=torch.bfloat16).eval() -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.2-1B-Instruct") - parser.add_argument("--output_path", type=str, default=".") - parser.add_argument("--dataset_name", type=str, default="kmfoda/booksum") - parser.add_argument("--num_samples", type=int, default=100) - parser.add_argument("--sample_seq_len", type=int, default=1000) - parser.add_argument("--n_sink", type=int, default=4) - parser.add_argument("--text_column", type=str, default="chapter") - parser.add_argument("--device_map", type=str, default="auto") - args = parser.parse_args() - model = AutoModelForCausalLM.from_pretrained( - args.model_name, device_map=args.device_map, torch_dtype=torch.bfloat16 - ).eval() - _, mu, cov = collect_queries( - model, args.dataset_name, args.num_samples, args.sample_seq_len, args.n_sink, return_stats=True - ) + _, mu, cov = collect_queries(model, dataset_name, num_samples, sample_seq_len, n_sink, text_column) stats = ExpectedAttentionStats( num_layers=model.config.num_hidden_layers, num_heads=model.config.num_attention_heads, head_dim=model.config.head_dim, - dataset_name=args.dataset_name, - model_name=args.model_name, - num_samples=args.num_samples, - sample_seq_len=args.sample_seq_len, - n_sink=args.n_sink, + dataset_name=dataset_name, + model_name=model_name, + num_samples=num_samples, + sample_seq_len=sample_seq_len, + n_sink=n_sink, ) - output_path = os.path.join(args.output_path, stats.stats_id()) + stats.query_mean.data = mu + stats.query_cov.data = cov + + output_path = os.path.join(output_path, stats.stats_id()) stats.save_pretrained(output_path) + print(f"Statistics saved to: {output_path}") + + +if __name__ == "__main__": + fire.Fire(main) From aee448303f191e803018e401f21843c17efd7229 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Wed, 27 Aug 2025 14:47:33 +0000 Subject: [PATCH 05/10] minor Signed-off-by: alessiodevoto --- kvpress/presses/expected_attention_with_stats.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/kvpress/presses/expected_attention_with_stats.py b/kvpress/presses/expected_attention_with_stats.py index 768a6676..30938349 100644 --- a/kvpress/presses/expected_attention_with_stats.py +++ b/kvpress/presses/expected_attention_with_stats.py @@ -38,9 +38,9 @@ class ExpectedAttentionStatsPress(ExpectedAttentionPress): Whether to rescale scores using value vector norms. epsilon : float, default=0.0 Small constant added to scores before value norm rescaling. - stats_dataset : str, default="kmfoda/booksum" + dataset_name : str, default="kmfoda/booksum" Dataset used to compute the statistics. - n_samples : int, default=100 + num_samples : int, default=100 Number of samples used to compute the statistics. sample_seq_len : int, default=1000 Sequence length used to compute the statistics. @@ -48,8 +48,8 @@ class ExpectedAttentionStatsPress(ExpectedAttentionPress): # Override parent defaults to enable stats by default sample_seq_len: int = 1000 - n_samples: int = 100 - stats_dataset: str = "kmfoda/booksum" + num_samples: int = 100 + dataset_name: str = "kmfoda/booksum" stats_folder: Optional[str] = None mu: torch.Tensor = field(init=False, default=None) # initialized in __post_init_from_model__ @@ -81,8 +81,8 @@ def _maybe_load_stats_from_hub(self, model: PreTrainedModel): num_layers=model.config.num_hidden_layers, num_heads=model.config.num_attention_heads, head_dim=model.config.head_dim, - dataset_name=self.stats_dataset, - num_samples=self.n_samples, + dataset_name=self.dataset_name, + num_samples=self.num_samples, sample_seq_len=self.sample_seq_len, n_sink=self.n_sink, ).stats_id() @@ -189,7 +189,7 @@ def collect_queries( model: PreTrainedModel, dataset_name: str, num_samples: int, - q_len: int, + sample_seq_len: int, n_sink: int, text_column: str = "chapter", ) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]: @@ -221,7 +221,7 @@ def collect_queries( dataset = load_dataset(dataset_name, split=f"train[:{num_samples}]") # Cut to max q_len - dataset = dataset.map(lambda x: {text_column: x[text_column][:q_len]}) + dataset = dataset.map(lambda x: {text_column: x[text_column][:sample_seq_len]}) collected_queries = [] for text in tqdm(dataset[text_column], desc="Collecting queries"): From 29444021026d48d5a75de3c1616f039506d16a31 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Wed, 27 Aug 2025 15:57:11 +0000 Subject: [PATCH 06/10] fix index layer Signed-off-by: alessiodevoto --- kvpress/presses/expected_attention_with_stats.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kvpress/presses/expected_attention_with_stats.py b/kvpress/presses/expected_attention_with_stats.py index 30938349..ca7fa2ba 100644 --- a/kvpress/presses/expected_attention_with_stats.py +++ b/kvpress/presses/expected_attention_with_stats.py @@ -59,8 +59,9 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): """ Override the parent method to use the pre-computed query statistics. """ - mu, cov = self.apply_avg_rope(module, self.mu, self.cov, hidden_states.shape[1]) - return mu[module.layer_idx].unsqueeze(0), cov[module.layer_idx].unsqueeze(0) + q_len = hidden_states.shape[1] + mu, cov = self.apply_avg_rope(module, self.mu[module.layer_idx], self.cov[module.layer_idx], q_len) + return mu.unsqueeze(0), cov.unsqueeze(0) def __post_init_from_model__(self, model): """ From d9803b64c6d0cef06e889c3bfe8a7e37fac28c82 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Thu, 28 Aug 2025 08:58:01 +0000 Subject: [PATCH 07/10] update Signed-off-by: alessiodevoto --- kvpress/presses/expected_attention_with_stats.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kvpress/presses/expected_attention_with_stats.py b/kvpress/presses/expected_attention_with_stats.py index ca7fa2ba..90dc47cb 100644 --- a/kvpress/presses/expected_attention_with_stats.py +++ b/kvpress/presses/expected_attention_with_stats.py @@ -60,7 +60,8 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): Override the parent method to use the pre-computed query statistics. """ q_len = hidden_states.shape[1] - mu, cov = self.apply_avg_rope(module, self.mu[module.layer_idx], self.cov[module.layer_idx], q_len) + layer_idx = module.layer_idx + mu, cov = self.apply_avg_rope(module, self.mu[layer_idx], self.cov[layer_idx], q_len) # type: ignore return mu.unsqueeze(0), cov.unsqueeze(0) def __post_init_from_model__(self, model): @@ -172,7 +173,7 @@ def patch_rotary_embedding(model): def patched_function(q_embed, k_embed, *args, **kwargs): # Capture the query tensor before RoPE is applied - captured_tensors.append(q_embed.detach()) + captured_tensors.append(q_embed.detach().cpu()) q_embed, k_embed = original_function(q_embed, k_embed, *args, **kwargs) return q_embed, k_embed From e4157b2f034f19691d2eb9ef45a770cfdc5b4877 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Thu, 28 Aug 2025 09:58:03 +0000 Subject: [PATCH 08/10] tests Signed-off-by: alessiodevoto --- kvpress/presses/expected_attention_with_stats.py | 7 ++++++- tests/default_presses.py | 2 ++ tests/presses/test_ea_with_stats.py | 8 ++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 tests/presses/test_ea_with_stats.py diff --git a/kvpress/presses/expected_attention_with_stats.py b/kvpress/presses/expected_attention_with_stats.py index 90dc47cb..195f6b8d 100644 --- a/kvpress/presses/expected_attention_with_stats.py +++ b/kvpress/presses/expected_attention_with_stats.py @@ -10,7 +10,7 @@ import fire import torch from datasets import load_dataset -from huggingface_hub import PyTorchModelHubMixin +from huggingface_hub import PyTorchModelHubMixin, get_collection from torch import nn from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel @@ -64,6 +64,11 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): mu, cov = self.apply_avg_rope(module, self.mu[layer_idx], self.cov[layer_idx], q_len) # type: ignore return mu.unsqueeze(0), cov.unsqueeze(0) + @staticmethod + def available_stats(): + collection = get_collection("alessiodevoto/expectedattentionstats-68b0248d519303713320e2cf") + return [x.item_id for x in collection.items] + def __post_init_from_model__(self, model): """ Automatically load or compute query statistics for the model. diff --git a/tests/default_presses.py b/tests/default_presses.py index fc7635cc..5c954766 100644 --- a/tests/default_presses.py +++ b/tests/default_presses.py @@ -6,6 +6,7 @@ from kvpress import ( DuoAttentionPress, ExpectedAttentionPress, + ExpectedAttentionStatsPress, KeyDiffPress, KnormPress, KVzipPress, @@ -34,6 +35,7 @@ def load_attention_pattern(model): {"cls": TestDuoAttentionPress, "kwargs": [{"head_compression_ratio": 0.2}, {"head_compression_ratio": 0.8}]}, {"cls": KnormPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + {"cls": ExpectedAttentionStatsPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": QFilterPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, diff --git a/tests/presses/test_ea_with_stats.py b/tests/presses/test_ea_with_stats.py new file mode 100644 index 00000000..d76c6853 --- /dev/null +++ b/tests/presses/test_ea_with_stats.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStats, ExpectedAttentionStatsPress + + +def test_load_stats(): + for stats_id in ExpectedAttentionStatsPress.available_stats(): + ExpectedAttentionStats.from_pretrained(stats_id) From 7f939614a3a43ed9a3921b7ac602339b28a2dfa8 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Thu, 28 Aug 2025 10:50:16 +0000 Subject: [PATCH 09/10] fix deps Signed-off-by: alessiodevoto --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bf9e226c..6a4b5c16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "accelerate>=1.0.0,<2", "requests>=2.32.3,<3", "cachetools>=5.5.2,<6", + "fire>=0.6.0,<0.7", ] [project.optional-dependencies] @@ -30,7 +31,6 @@ eval = [ "nltk>=3.9.1,<4", "tqdm>=4.66.4,<5", "scipy>=1.13.1,<2", - "fire>=0.6.0,<0.7", "bert-score>=0.3.13,<0.4", ] flash-attn = [ From bc27dd8f8ca43a9cc67338f38fdcc42101eb26df Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Thu, 28 Aug 2025 14:36:51 +0000 Subject: [PATCH 10/10] polish Signed-off-by: alessiodevoto --- kvpress/presses/expected_attention_press.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index 4162a255..bf655a37 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -83,16 +83,6 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): # Apply RoPE to the mean and covariance matrix of the queries mu, cov = self.apply_avg_rope(module, mu, cov, q_len) - # Instead of using the average rotation matrix, we could use a mixture of gaussian statistics to - # estimate mean and covariance. Estimation is better, but end-to-end performance was lower. - # mu = torch.einsum("bhj, fij -> bhfi", mu, R) - # mean_mu = mu.mean(dim=2, keepdim=True) - # if self.use_covariance: - # cov = torch.einsum("fki, bhkl, fjl -> bhfij", R, cov, R) - # cov = cov.mean(dim=2) - # cov += torch.einsum("bhfi, bhfj -> bhji", mu - mean_mu, mu - mean_mu) / self.n_future_positions - # mu = mean_mu.squeeze(2) - return mu, cov def apply_avg_rope(self, module: nn.Module, mu: torch.Tensor, cov: torch.Tensor, q_len: int):