diff --git a/kvpress/__init__.py b/kvpress/__init__.py index bae791d8..cccd1f02 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -13,6 +13,7 @@ from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress from kvpress.presses.duo_attention_press import DuoAttentionPress from kvpress.presses.expected_attention_press import ExpectedAttentionPress +from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress from kvpress.presses.finch_press import FinchPress from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.keydiff_press import KeyDiffPress @@ -63,4 +64,5 @@ "BlockPress", "KeyDiffPress", "KVzipPress", + "ExpectedAttentionStatsPress", ] diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index e22e2b02..bf655a37 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -65,7 +65,6 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): """ q_len = hidden_states.shape[1] - head_dim = module.head_dim # Remove first hidden_states that likely contain outliers h = hidden_states[:, self.n_sink :] @@ -81,34 +80,47 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): cov = torch.einsum("bnsi,bnsj->bnij", centered_states, centered_states) / h.shape[1] mu = mu.squeeze(2) - # RoPE rotation matrix on next n_future_positions + # Apply RoPE to the mean and covariance matrix of the queries + mu, cov = self.apply_avg_rope(module, mu, cov, q_len) + + return mu, cov + + def apply_avg_rope(self, module: nn.Module, mu: torch.Tensor, cov: torch.Tensor, q_len: int): + """ + Apply average RoPE to the mean and covariance matrix of the queries + + Parameters + ---------- + module : nn.Module + The module to apply RoPE to. + mu : torch.Tensor + The mean of the queries. + cov : torch.Tensor + The covariance matrix of the queries. + q_len : int + The length of the queries. + + Returns + ------- + mu : torch.Tensor + The mean of the queries after RoPE. + cov : torch.Tensor + The covariance matrix of the queries after RoPE. + """ position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device) + head_dim = module.head_dim cos, sin = module.rotary_emb(mu, position_ids) cos, sin = cos[0], sin[0] - Id = torch.eye(head_dim, device=cos.device, dtype=cos.dtype) P = torch.zeros((head_dim, head_dim), device=cos.device, dtype=cos.dtype) P[head_dim // 2 :, : head_dim // 2], P[: head_dim // 2, head_dim // 2 :] = torch.eye(head_dim // 2), -torch.eye( head_dim // 2 ) R = cos.unsqueeze(1) * Id + sin.unsqueeze(1) * P - - # Apply average rotation to the mean and covariance R = R.mean(dim=0).to(mu.device) mu = torch.matmul(mu, R.T) - if self.use_covariance: + if cov is not None: cov = torch.matmul(R, torch.matmul(cov, R.T)) - - # Instead of using the average rotation matrix, we could use a mixture of gaussian statistics to - # estimate mean and covariance. Estimation is better, but end-to-end performance was lower. - # mu = torch.einsum("bhj, fij -> bhfi", mu, R) - # mean_mu = mu.mean(dim=2, keepdim=True) - # if self.use_covariance: - # cov = torch.einsum("fki, bhkl, fjl -> bhfij", R, cov, R) - # cov = cov.mean(dim=2) - # cov += torch.einsum("bhfi, bhfj -> bhji", mu - mean_mu, mu - mean_mu) / self.n_future_positions - # mu = mean_mu.squeeze(2) - return mu, cov def score( diff --git a/kvpress/presses/expected_attention_with_stats.py b/kvpress/presses/expected_attention_with_stats.py new file mode 100644 index 00000000..195f6b8d --- /dev/null +++ b/kvpress/presses/expected_attention_with_stats.py @@ -0,0 +1,295 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import importlib +import os +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Optional + +import fire +import torch +from datasets import load_dataset +from huggingface_hub import PyTorchModelHubMixin, get_collection +from torch import nn +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel + +from kvpress.presses.expected_attention_press import ExpectedAttentionPress + + +@dataclass +class ExpectedAttentionStatsPress(ExpectedAttentionPress): + """ + Expected attention press that automatically loads pre-computed query statistics. + + + Parameters + ---------- + compression_ratio : float, default=0.0 + Fraction of key-value pairs to remove during compression. + n_future_positions : int, default=512 + Number of future positions to consider when computing expected attention. + n_sink : int, default=4 + Number of initial tokens to exclude from compression (sink tokens). + use_covariance : bool, default=True + Whether to include covariance information in expected attention computation. + use_vnorm : bool, default=True + Whether to rescale scores using value vector norms. + epsilon : float, default=0.0 + Small constant added to scores before value norm rescaling. + dataset_name : str, default="kmfoda/booksum" + Dataset used to compute the statistics. + num_samples : int, default=100 + Number of samples used to compute the statistics. + sample_seq_len : int, default=1000 + Sequence length used to compute the statistics. + """ + + # Override parent defaults to enable stats by default + sample_seq_len: int = 1000 + num_samples: int = 100 + dataset_name: str = "kmfoda/booksum" + stats_folder: Optional[str] = None + + mu: torch.Tensor = field(init=False, default=None) # initialized in __post_init_from_model__ + cov: torch.Tensor = field(init=False, default=None) # initialized in __post_init_from_model__ + + def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): + """ + Override the parent method to use the pre-computed query statistics. + """ + q_len = hidden_states.shape[1] + layer_idx = module.layer_idx + mu, cov = self.apply_avg_rope(module, self.mu[layer_idx], self.cov[layer_idx], q_len) # type: ignore + return mu.unsqueeze(0), cov.unsqueeze(0) + + @staticmethod + def available_stats(): + collection = get_collection("alessiodevoto/expectedattentionstats-68b0248d519303713320e2cf") + return [x.item_id for x in collection.items] + + def __post_init_from_model__(self, model): + """ + Automatically load or compute query statistics for the model. + """ + if self.mu is None and self.cov is None: + if self.stats_folder is not None: + stats = ExpectedAttentionStats.from_pretrained(self.stats_folder) + else: + stats = self._maybe_load_stats_from_hub(model) + self.mu = stats.query_mean.data.to(model.device, dtype=model.dtype) + self.cov = stats.query_cov.data.to(model.device, dtype=model.dtype) + + def _maybe_load_stats_from_hub(self, model: PreTrainedModel): + """Load statistics from the Hugging Face Hub.""" + stats_id = ExpectedAttentionStats( + model_name=model.config.name_or_path, + num_layers=model.config.num_hidden_layers, + num_heads=model.config.num_attention_heads, + head_dim=model.config.head_dim, + dataset_name=self.dataset_name, + num_samples=self.num_samples, + sample_seq_len=self.sample_seq_len, + n_sink=self.n_sink, + ).stats_id() + try: + return ExpectedAttentionStats.from_pretrained(stats_id) + except ValueError: + raise ValueError( + f"No statistics found for model {stats_id} on the Hub. Please compute them first. " + "You can do so by running the following code: " + "```" + "python expected_attention_with_stats.py --model_name " + "```" + ) + + @contextmanager + def __call__(self, model): + self.__post_init_from_model__(model) + with super().__call__(model): + yield + + +class ExpectedAttentionStats(torch.nn.Module, PyTorchModelHubMixin): + """ + Module that stores the mean and covariance matrix of the queries, possibly uploaded to the HF hub. + """ + + def __init__( + self, + num_layers: int, + num_heads: int, + head_dim: int, + dataset_name: str, + model_name: str, + num_samples: int, + sample_seq_len: int, + n_sink: int, + ): + super().__init__() + self.query_mean = torch.nn.Parameter(torch.zeros(num_layers, num_heads, head_dim)) + self.query_cov = torch.nn.Parameter(torch.zeros(num_layers, num_heads, head_dim, head_dim)) + self.dataset_name = dataset_name + self.model_name = model_name + self.num_samples = num_samples + self.sample_seq_len = sample_seq_len + self.n_sink = n_sink + + def stats_id(self) -> str: + """Generate the statistics ID for the model and configuration.""" + return f"alessiodevoto/exp_att_stats_{self.model_name.replace('/', '_')}_{self.dataset_name.replace('/', '_')}_{self.num_samples}_{self.sample_seq_len}_{self.n_sink}" # noqa: E501 + + +# The code below is used to collect statistics on a dataset. + + +@contextmanager +def patch_rotary_embedding(model): + """ + A context manager to dynamically patch the `apply_rotary_pos_emb` function + for any supported model architecture. It captures the query states before + rotary embeddings are applied. + + Args: + model (PreTrainedModel): The transformer model instance. + + Yields: + list: A list that will be populated with the captured query tensors. + """ + # Dynamically find the model's specific "modeling" module + try: + module_path = model.__class__.__module__ + modeling_module = importlib.import_module(module_path) + except Exception as e: + raise RuntimeError(f"Failed to import module for {model.__class__.__name__}: {e}") + + # Check for the target function and save the original + target_function = "apply_rotary_pos_emb" + if not hasattr(modeling_module, target_function): + raise AttributeError( + f"Model architecture '{model.config.model_type}' is not supported. " + f"The module '{module_path}' does not contain '{target_function}'." + ) + + original_function = getattr(modeling_module, target_function) + + captured_tensors = [] + + def patched_function(q_embed, k_embed, *args, **kwargs): + # Capture the query tensor before RoPE is applied + captured_tensors.append(q_embed.detach().cpu()) + q_embed, k_embed = original_function(q_embed, k_embed, *args, **kwargs) + return q_embed, k_embed + + # Apply the patch + setattr(modeling_module, target_function, patched_function) + + try: + yield captured_tensors + finally: + setattr(modeling_module, target_function, original_function) + + +@torch.inference_mode() +def collect_queries( + model: PreTrainedModel, + dataset_name: str, + num_samples: int, + sample_seq_len: int, + n_sink: int, + text_column: str = "chapter", +) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]: + """ + Collects query representations from a transformer model using a calibration dataset. + + This function runs the model on a small number of samples from the "kmfoda/booksum" dataset, + capturing the query tensors after rotary positional embeddings are applied. It trims the + input text to a maximum length (`q_len`), skips the first `n_sink` tokens (to avoid outliers), + and returns the collected queries. + + Args: + model (PreTrainedModel): The transformer model instance. + dataset_name (str): Name of the dataset to use for collecting statistics. + num_samples (int): Number of samples to use from the calibration dataset. + q_len (int): Maximum sequence length to consider for each sample. + n_sink (int): Number of initial tokens to exclude from the collected queries. + text_column (str): Name of the column in the dataset containing the text to tokenize. + + Returns: + list or tuple: + collected_queries (list): List of query tensors, each of shape (num_layers, num_heads, seq_len, head_dim) + mean_query (torch.Tensor): Mean query vector for each layer and head. + cov_query (torch.Tensor): Covariance matrix of queries for each layer and head. + """ + + # Load dataset and tokenizer + tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) + dataset = load_dataset(dataset_name, split=f"train[:{num_samples}]") + + # Cut to max q_len + dataset = dataset.map(lambda x: {text_column: x[text_column][:sample_seq_len]}) + + collected_queries = [] + for text in tqdm(dataset[text_column], desc="Collecting queries"): + inputs = tokenizer(text, return_tensors="pt").to(model.device) + with patch_rotary_embedding(model) as captured_queries: + model(**inputs) + collected_queries.append(torch.cat(captured_queries, dim=0)[:, :, n_sink:, :]) + + cat_queries = torch.cat(collected_queries, dim=-2) + mean_query = cat_queries.mean(dim=-2) + # compute covariance manually + centered_queries = cat_queries - mean_query.unsqueeze(-2) + N = cat_queries.shape[-2] + cov_query = (centered_queries.transpose(-2, -1) @ centered_queries) / (N - 1) + return collected_queries, mean_query, cov_query + + +def main( + model_name: str = "meta-llama/Llama-3.1-8B-Instruct", + output_path: str = ".", + dataset_name: str = "kmfoda/booksum", + num_samples: int = 100, + sample_seq_len: int = 1000, + n_sink: int = 4, + text_column: str = "chapter", + device_map: str = "auto", +): + """ + Collect query statistics for a transformer model and save them. + + Args: + model_name: Name of the model to collect statistics for + output_path: Directory to save the statistics + dataset_name: Dataset to use for collecting statistics + num_samples: Number of samples to use from the dataset + sample_seq_len: Sequence length for each sample + n_sink: Number of initial tokens to exclude + text_column: Column name containing text in the dataset + device_map: Device mapping for the model + """ + model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, torch_dtype=torch.bfloat16).eval() + + _, mu, cov = collect_queries(model, dataset_name, num_samples, sample_seq_len, n_sink, text_column) + + stats = ExpectedAttentionStats( + num_layers=model.config.num_hidden_layers, + num_heads=model.config.num_attention_heads, + head_dim=model.config.head_dim, + dataset_name=dataset_name, + model_name=model_name, + num_samples=num_samples, + sample_seq_len=sample_seq_len, + n_sink=n_sink, + ) + stats.query_mean.data = mu + stats.query_cov.data = cov + + output_path = os.path.join(output_path, stats.stats_id()) + stats.save_pretrained(output_path) + print(f"Statistics saved to: {output_path}") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/pyproject.toml b/pyproject.toml index bf9e226c..6a4b5c16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "accelerate>=1.0.0,<2", "requests>=2.32.3,<3", "cachetools>=5.5.2,<6", + "fire>=0.6.0,<0.7", ] [project.optional-dependencies] @@ -30,7 +31,6 @@ eval = [ "nltk>=3.9.1,<4", "tqdm>=4.66.4,<5", "scipy>=1.13.1,<2", - "fire>=0.6.0,<0.7", "bert-score>=0.3.13,<0.4", ] flash-attn = [ diff --git a/tests/default_presses.py b/tests/default_presses.py index fc7635cc..5c954766 100644 --- a/tests/default_presses.py +++ b/tests/default_presses.py @@ -6,6 +6,7 @@ from kvpress import ( DuoAttentionPress, ExpectedAttentionPress, + ExpectedAttentionStatsPress, KeyDiffPress, KnormPress, KVzipPress, @@ -34,6 +35,7 @@ def load_attention_pattern(model): {"cls": TestDuoAttentionPress, "kwargs": [{"head_compression_ratio": 0.2}, {"head_compression_ratio": 0.8}]}, {"cls": KnormPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + {"cls": ExpectedAttentionStatsPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": QFilterPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, diff --git a/tests/presses/test_ea_with_stats.py b/tests/presses/test_ea_with_stats.py new file mode 100644 index 00000000..d76c6853 --- /dev/null +++ b/tests/presses/test_ea_with_stats.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStats, ExpectedAttentionStatsPress + + +def test_load_stats(): + for stats_id in ExpectedAttentionStatsPress.available_stats(): + ExpectedAttentionStats.from_pretrained(stats_id)