diff --git a/kvpress/presses/decoding_press.py b/kvpress/presses/decoding_press.py
index 56d4b6b1..34eb1c40 100644
--- a/kvpress/presses/decoding_press.py
+++ b/kvpress/presses/decoding_press.py
@@ -3,10 +3,12 @@
 
 import logging
 from collections import defaultdict
+from contextlib import contextmanager
 from dataclasses import dataclass
 
 import torch
 import torch.nn as nn
+from transformers import PreTrainedModel
 from transformers.cache_utils import QuantizedCache
 
 from kvpress.presses.adakv_press import AdaKVPress
@@ -179,6 +181,23 @@ def reset(self):
         self.hidden_states_buffer = defaultdict(list)
         self.layer_step_counts = defaultdict(int)
 
+    @contextmanager
+    def __call__(self, model: PreTrainedModel):
+        """
+        Enter the parent class's ``__call__(model)`` context and reset this press on exit.
+
+        The parent context stays active for the duration of the caller's ``with`` body.
+        ``reset()`` is then called unconditionally, including when the body raises, so
+        ``hidden_states_buffer`` and ``layer_step_counts`` do not carry over to the next use.
+        """
+        try:
+            parent_context = super().__call__(model)
+            with parent_context:
+                yield
+        finally:
+            # Always clear per-run state, whether the with-body finished or raised.
+            self.reset()
+
     def _find_target_compression_ratio(self, q_len: int, target_tokens: int) -> float:
         """
         Find the compression ratio that results in exactly target_tokens after int() rounding.