diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index dae84775..7b1b416a 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -4,10 +4,13 @@ Description of your PR. Fixes # (issue) (if applicable) ## Checklist -- Tests are working (`make test`) -- Code is formatted correctly (`make style`, on errors try fix with `make format`) -- Copyright header is included +Before submitting a PR, please make sure: + +- [ ] Tests are working (`make test`) +- [ ] Code is formatted correctly (`make style`, on errors try fix with `make format`) +- [ ] Copyright header is included - [ ] All commits are signed-off using `git commit -s` + - [ ] (new press) `mypress_press.py` is in the `presses` directory - [ ] (new press) `MyPress` is in `__init__.py` - [ ] (new press) `README.md` is updated with a 1 liner about the new press in the Available presses section diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 113f689a..8d525da3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,6 +16,14 @@ jobs: with: python-version: 3.10.11 + - name: Setup CUDA + uses: Jimver/cuda-toolkit@v0.2.16 + with: + cuda: '12.5.0' + + - name: Set CUDA_HOME + run: echo "CUDA_HOME=/usr/local/cuda" >> $GITHUB_ENV + - name: Install uv uses: astral-sh/setup-uv@v6 with: @@ -25,3 +33,5 @@ jobs: run: uv sync --all-groups - run: make test + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} diff --git a/Makefile b/Makefile index fbd14aa9..99164e7a 100644 --- a/Makefile +++ b/Makefile @@ -41,9 +41,22 @@ reports: .PHONY: test test: reports + $(UV) pip install optimum-quanto + $(UV) pip install flash-attn PYTHONPATH=. \ $(UV) run pytest \ --cov-report xml:reports/coverage.xml \ --cov=kvpress/ \ --junitxml=./reports/junit.xml \ - tests/ + -v \ + tests/ | tee reports/pytest_output.log + @if grep -q "SKIPPED" reports/pytest_output.log; then \ + echo "Error: Tests were skipped. All tests must run."; \ + grep "SKIPPED" reports/pytest_output.log; \ + exit 1; \ + fi + @if grep -q "FAILED" reports/pytest_output.log; then \ + echo "Error: Some tests failed."; \ + grep "FAILED" reports/pytest_output.log; \ + exit 1; \ + fi diff --git a/pyproject.toml b/pyproject.toml index 7bc1e993..935cf50f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,4 +91,4 @@ disable_error_code = ["attr-defined"] [[tool.mypy.overrides]] module = "kvpress.pipeline" -disable_error_code = ["attr-defined", "assignment", "override"] \ No newline at end of file +disable_error_code = ["attr-defined", "assignment", "override"] diff --git a/tests/fixtures.py b/tests/fixtures.py index c446db4d..9ebceb19 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -7,21 +7,29 @@ from transformers import AutoModelForCausalLM, pipeline +def get_device(): + """Helper function that returns the appropriate device (GPU if available, otherwise CPU)""" + return "cuda:0" if torch.cuda.is_available() else "cpu" + + @pytest.fixture(scope="session") def unit_test_model(): - return AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval() + model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval() + return model.to(get_device()) @pytest.fixture(scope="session") def unit_test_model_output_attention(): - return AutoModelForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( "MaxJeblick/llama2-0b-unit-test", attn_implementation="eager", output_attentions=True ).eval() + return model.to(get_device()) @pytest.fixture(scope="session") def danube_500m_model(): - return AutoModelForCausalLM.from_pretrained("h2oai/h2o-danube3-500m-chat").eval() + model = AutoModelForCausalLM.from_pretrained("h2oai/h2o-danube3-500m-chat").eval() + return model.to(get_device()) @pytest.fixture(scope="session") @@ -29,7 +37,7 @@ def kv_press_unit_test_pipeline(): return pipeline( "kv-press-text-generation", model="maxjeblick/llama2-0b-unit-test", - device=0 if torch.cuda.is_available() else -1, + device=get_device(), ) @@ -38,11 +46,46 @@ def kv_press_danube_pipeline(): return pipeline( "kv-press-text-generation", model="h2oai/h2o-danube3-500m-chat", - device=0 if torch.cuda.is_available() else -1, + device=get_device(), ) @pytest.fixture(scope="session") +def kv_press_adaptive_pipeline(): + """Flexible pipeline that uses GPU+flash attention if available, otherwise CPU""" + device = get_device() + ckpt = "meta-llama/Llama-3.2-1B-Instruct" + + # Use flash attention only if GPU is available + model_kwargs = {} + if torch.cuda.is_available(): + model_kwargs["attn_implementation"] = "flash_attention_2" + + pipe = pipeline( + "kv-press-text-generation", + model=ckpt, + device=device, + torch_dtype="auto", + model_kwargs=model_kwargs, + ) + return pipe + + +@pytest.fixture(scope="class") +def kv_press_llama3_1_flash_attn_pipeline(): + device = "cuda:0" + ckpt = "meta-llama/Llama-3.1-8B-Instruct" + attn_implementation = "flash_attention_2" + pipe = pipeline( + "kv-press-text-generation", + model=ckpt, + device=device, + model_kwargs={"attn_implementation": attn_implementation, "torch_dtype": torch.bfloat16}, + ) + return pipe + + +@pytest.fixture(scope="class") def kv_press_llama3_2_flash_attn_pipeline(): device = "cuda:0" ckpt = "meta-llama/Llama-3.2-1B-Instruct" @@ -56,10 +99,10 @@ def kv_press_llama3_2_flash_attn_pipeline(): return pipe -@pytest.fixture(scope="session") -def kv_press_llama3_1_flash_attn_pipeline(): +@pytest.fixture(scope="class") +def kv_press_qwen3_flash_attn_pipeline(): device = "cuda:0" - ckpt = "meta-llama/Llama-3.1-8B-Instruct" + ckpt = "Qwen/Qwen3-4B-Instruct-2507" attn_implementation = "flash_attention_2" pipe = pipeline( "kv-press-text-generation", diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py index f6975919..43cfe290 100644 --- a/tests/integration/test_ruler.py +++ b/tests/integration/test_ruler.py @@ -6,9 +6,9 @@ import torch from transformers import DynamicCache, QuantoQuantizedCache from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available - +from kvpress import QFilterPress from tests.default_presses import default_presses -from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401 +from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_qwen3_flash_attn_pipeline # noqa: F401 @pytest.fixture(scope="session") @@ -18,40 +18,94 @@ def df_ruler(): return df -@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") -@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") -@pytest.mark.parametrize("press_dict", default_presses) -@pytest.mark.parametrize("cache", ["dynamic", "quantized"]) -@pytest.mark.parametrize("compression_ratio", [0, 0.1]) -def test_ruler_is_correct( - kv_press_llama3_1_flash_attn_pipeline, df_ruler, press_dict, cache, compression_ratio # noqa: F811 -): - cls = press_dict["cls"] - kwargs = press_dict["kwargs"][0] - press = cls(**kwargs) - if not hasattr(cls, "compression_ratio"): - pytest.skip(reason="Press does not support compression_ratio") - try: - # set compression ratio to a small value for testing - # we don't want to max out compression, but rather test if cache compression works - press.compression_ratio = compression_ratio - except AttributeError: - # pytest.skip(reason="Press does not support setting compression_ratio") - pass - - if cache == "dynamic": - cache = DynamicCache() - elif cache == "quantized" and is_optimum_quanto_available(): - cache = QuantoQuantizedCache(config=kv_press_llama3_1_flash_attn_pipeline.model.config, nbits=4) - elif cache == "quantized" and not is_optimum_quanto_available(): - pytest.skip("Quanto is not installed") - else: - raise ValueError(f"Unknown cache type: {cache}") - - idx = 0 - context = df_ruler.iloc[idx]["context"] - question = df_ruler.iloc[idx]["question"] - true_answer = df_ruler.iloc[idx]["answer"][0] - - pred_answer = kv_press_llama3_1_flash_attn_pipeline(context, question=question, press=press, cache=cache)["answer"] - assert true_answer in pred_answer +class TestRuler: + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") + @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") + @pytest.mark.parametrize("press_dict", default_presses) + @pytest.mark.parametrize("cache", ["dynamic", "quantized"]) + @pytest.mark.parametrize("compression_ratio", [0, 0.1]) + def test_ruler_is_correct( + self, kv_press_qwen3_flash_attn_pipeline, df_ruler, press_dict, cache, compression_ratio # noqa: F811 + ): + cls = press_dict["cls"] + kwargs = press_dict["kwargs"][0] + press = cls(**kwargs) + if not hasattr(cls, "compression_ratio"): + pytest.skip(reason="Press does not support compression_ratio") + try: + # set compression ratio to a small value for testing + # we don't want to max out compression, but rather test if cache compression works + press.compression_ratio = compression_ratio + except AttributeError: + # pytest.skip(reason="Press does not support setting compression_ratio") + pass + + if cache == "dynamic": + cache = DynamicCache() + elif cache == "quantized" and is_optimum_quanto_available(): + cache = QuantoQuantizedCache(config=kv_press_qwen3_flash_attn_pipeline.model.config, nbits=4) + elif cache == "quantized" and not is_optimum_quanto_available(): + pytest.skip("Quanto is not installed") + else: + raise ValueError(f"Unknown cache type: {cache}") + + idx = 6 # qwen model passed idx 6 for all configurations + context = df_ruler.iloc[idx]["context"] + question = df_ruler.iloc[idx]["question"] + true_answer = df_ruler.iloc[idx]["answer"][0] + + if isinstance(press, QFilterPress): + # QFilterPress doesn't support Qwen3 4B. Will be tested in the next test class. + return + else: + pred_answer = kv_press_qwen3_flash_attn_pipeline( + context, + question=question, + press=press, + cache=cache + )["answer"] + assert true_answer in pred_answer + + +class TestRulerForQFilter: + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") + @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") + @pytest.mark.parametrize("cache", ["dynamic", "quantized"]) + @pytest.mark.parametrize("compression_ratio", [0, 0.1]) + def test_ruler_is_correct_for_qfilter( + self, kv_press_llama3_2_flash_attn_pipeline, df_ruler, cache, compression_ratio # noqa: F811 + ): + cls = QFilterPress + kwargs = {"compression_ratio": 0.2} + press = cls(**kwargs) + if not hasattr(cls, "compression_ratio"): + pytest.skip(reason="Press does not support compression_ratio") + try: + # set compression ratio to a small value for testing + # we don't want to max out compression, but rather test if cache compression works + press.compression_ratio = compression_ratio + except AttributeError: + # pytest.skip(reason="Press does not support setting compression_ratio") + pass + + if cache == "dynamic": + cache = DynamicCache() + elif cache == "quantized" and is_optimum_quanto_available(): + cache = QuantoQuantizedCache(config=kv_press_llama3_2_flash_attn_pipeline.model.config, nbits=4) + elif cache == "quantized" and not is_optimum_quanto_available(): + pytest.skip("Quanto is not installed") + else: + raise ValueError(f"Unknown cache type: {cache}") + + idx = 0 + context = df_ruler.iloc[idx]["context"] + question = df_ruler.iloc[idx]["question"] + true_answer = df_ruler.iloc[idx]["answer"][0] + + pred_answer = kv_press_llama3_2_flash_attn_pipeline( + context, + question=question, + press=press, + cache=cache + )["answer"] + assert true_answer in pred_answer diff --git a/tests/presses/test_block_press.py b/tests/presses/test_block_press.py index dcae1af5..00c7301b 100644 --- a/tests/presses/test_block_press.py +++ b/tests/presses/test_block_press.py @@ -33,7 +33,7 @@ def test_block_press_is_streaming_top_k(unit_test_model): # noqa: F811 """ press = HiddenStatesPress(compression_ratio=0.5) generator = torch.Generator().manual_seed(0) - input_ids = torch.randint(0, 1024, (1, 256), generator=generator) + input_ids = torch.randint(0, 1024, (1, 256), generator=generator).to(unit_test_model.device) keys_hash = [] values_hash = [] diff --git a/tests/presses/test_finch_press.py b/tests/presses/test_finch_press.py index ce46c7f7..c579a181 100644 --- a/tests/presses/test_finch_press.py +++ b/tests/presses/test_finch_press.py @@ -16,6 +16,6 @@ def test_finch_press(unit_test_model): # noqa: F811 ]: press.delimiter_token_id = unit_test_model.config.eos_token_id with press(unit_test_model): - input_ids = torch.arange(10, 20) + input_ids = torch.arange(10, 20).to(unit_test_model.device) input_ids[8] = press.delimiter_token_id unit_test_model(input_ids.unsqueeze(0)) diff --git a/tests/presses/test_flash_attention.py b/tests/presses/test_flash_attention.py index 87b42a69..425f9b38 100644 --- a/tests/presses/test_flash_attention.py +++ b/tests/presses/test_flash_attention.py @@ -7,19 +7,20 @@ from transformers.utils import is_flash_attn_2_available from kvpress import KnormPress -from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401 +from tests.fixtures import kv_press_qwen3_flash_attn_pipeline # noqa: F401 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") -@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") -def test_fa_works(kv_press_llama3_1_flash_attn_pipeline): # noqa: F811 - # test if fa2 runs, see https://github.com/huggingface/transformers/releases/tag/v4.55.2 - # and https://github.com/NVIDIA/kvpress/pull/115 - model = kv_press_llama3_1_flash_attn_pipeline.model - tok = AutoTokenizer.from_pretrained("h2oai/h2o-danube3-500m-chat") - inputs = tok("Hello, how are you? bla bla how are you? this is some text lala ddd", return_tensors="pt").to( - model.device - ) +class TestFlashAttention: + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") + @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") + def test_fa_works(self, kv_press_qwen3_flash_attn_pipeline): # noqa: F811 + # test if fa2 runs, see https://github.com/huggingface/transformers/releases/tag/v4.55.2 + # and https://github.com/NVIDIA/kvpress/pull/115 + model = kv_press_qwen3_flash_attn_pipeline.model + tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507") + inputs = tok("Hello, how are you? bla bla how are you? this is some text lala ddd", return_tensors="pt").to( + model.device + ) - with KnormPress(0.8)(model): - model.generate(**inputs, max_new_tokens=10, do_sample=False) + with KnormPress(0.8)(model): + model.generate(**inputs, max_new_tokens=10, do_sample=False) diff --git a/tests/presses/test_head_compression.py b/tests/presses/test_head_compression.py index 83be7dd4..ebf316b1 100644 --- a/tests/presses/test_head_compression.py +++ b/tests/presses/test_head_compression.py @@ -28,7 +28,7 @@ def test_wrapper_head_compression(unit_test_model, wrapper_press, compression_ra p = KnormPress(compression_ratio=compression_ratio) press = wrapper_press(press=p) with press(unit_test_model): - input_ids = torch.randint(0, 1024, (1, 128)) + input_ids = torch.randint(0, 1024, (1, 128)).to(unit_test_model.device) unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values assert unit_test_model.model.layers[0].self_attn.masked_key_indices is not None @@ -47,7 +47,7 @@ def test_wrapper_head_compression(unit_test_model, wrapper_press, compression_ra def test_head_compression(unit_test_model, press, compression_ratio, layerwise): # noqa: F811 press = KVzipPress(compression_ratio=compression_ratio, layerwise=layerwise) with press(unit_test_model): - input_ids = torch.randint(0, 1024, (1, 128)) + input_ids = torch.randint(0, 1024, (1, 128)).to(unit_test_model.device) unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values assert unit_test_model.model.layers[0].self_attn.masked_key_indices is not None diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index d70d4328..4fc88701 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -31,7 +31,7 @@ def test_composed_press(unit_test_model): # noqa: F811 press2 = ThinKPress(key_channel_compression_ratio=0.5, window_size=2) composed_press = ComposedPress([press1, press2]) with composed_press(unit_test_model): - input_ids = unit_test_model.dummy_inputs["input_ids"] + input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device) unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values @@ -40,7 +40,7 @@ def test_chunk_press(unit_test_model): # noqa: F811 for chunk_length in [2, 4, 8, 128]: composed_press = ChunkPress(press=press, chunk_length=chunk_length) with composed_press(unit_test_model): - input_ids = torch.randint(0, 1024, (1, 256)) + input_ids = torch.randint(0, 1024, (1, 256), device=unit_test_model.device) cache = DynamicCache() unit_test_model(input_ids, past_key_values=cache).past_key_values assert cache.get_seq_length() == 128 @@ -51,7 +51,7 @@ def test_chunkkv_press(unit_test_model): # noqa: F811 for chunk_length in [2, 4, 8, 128]: composed_press = ChunkKVPress(press=press, chunk_length=chunk_length) with composed_press(unit_test_model): - input_ids = torch.randint(0, 1024, (1, 256)) + input_ids = torch.randint(0, 1024, (1, 256), device=unit_test_model.device) cache = DynamicCache() unit_test_model(input_ids, past_key_values=cache).past_key_values assert cache.get_seq_length() == 128 @@ -84,7 +84,7 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811 if hasattr(press, "__post_init_from_model__"): press.__post_init_from_model__(unit_test_model) with press(unit_test_model): - input_ids = torch.randint(0, 1024, (1, 128)) + input_ids = torch.randint(0, 1024, (1, 128), device=unit_test_model.device) unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values # Check that the press has a compression_ratio attribute assert hasattr(press, "compression_ratio") @@ -95,7 +95,9 @@ def test_presses_run_observed_attention(unit_test_model_output_attention): # no for compresion_ratio in [0.2, 0.8]: press = cls(compression_ratio=compresion_ratio) with press(unit_test_model_output_attention): - input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"] + input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"].to( + unit_test_model_output_attention.device + ) unit_test_model_output_attention(input_ids, past_key_values=DynamicCache()).past_key_values @@ -126,7 +128,7 @@ def test_presses_keep_highest_score(unit_test_model): # noqa: F811 for compresion_ratio in [0.0, 0.2, 0.4, 0.6, 0.8]: press = StoreKnormPress(compression_ratio=compresion_ratio) with press(unit_test_model): - input_ids = torch.randint(0, 3_000, (5, 256)) + input_ids = torch.randint(0, 3_000, (5, 256), device=unit_test_model.device) past_key_values = unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values keys = [layer.keys for layer in past_key_values.layers] diff --git a/tests/presses/test_wrappers.py b/tests/presses/test_wrappers.py index 7383dc36..5b46c755 100644 --- a/tests/presses/test_wrappers.py +++ b/tests/presses/test_wrappers.py @@ -14,7 +14,7 @@ def test_composed_press_qfilter_without_post_init(unit_test_model): # noqa: F81 composed_press = ComposedPress([press1, press2]) with pytest.raises(ValueError, match="post_init_from_model"): with composed_press(unit_test_model): - input_ids = unit_test_model.dummy_inputs["input_ids"] + input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device) unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values @@ -28,7 +28,7 @@ def test_composed_press_duo_attention_without_post_init(unit_test_model): # noq composed_press = ComposedPress([press1, press2]) with pytest.raises(ValueError, match="post_init_from_model"): with composed_press(unit_test_model): - input_ids = unit_test_model.dummy_inputs["input_ids"] + input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device) unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values @@ -45,7 +45,7 @@ def test_composed_qfilter_press_with_post_init(unit_test_model): # noqa: F811 composed_press = ComposedPress([press1, press2]) with composed_press(unit_test_model): - input_ids = unit_test_model.dummy_inputs["input_ids"] + input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device) with pytest.raises(RuntimeError, match="The size of tensor"): unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values @@ -63,5 +63,5 @@ def test_composed_duo_attention_press_with_post_init(unit_test_model): # noqa: composed_press = ComposedPress([press1, press2]) with composed_press(unit_test_model): - input_ids = unit_test_model.dummy_inputs["input_ids"] + input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device) unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values diff --git a/tests/test_attention_patch.py b/tests/test_attention_patch.py index 99733f38..5163acb2 100644 --- a/tests/test_attention_patch.py +++ b/tests/test_attention_patch.py @@ -7,7 +7,8 @@ def test_search_hyperplane(): + device = "cuda:0" if torch.cuda.is_available() else "cpu" bsz, seq_len, head_dim = 50, 500, 128 - X = torch.rand(bsz, seq_len, head_dim) + X = torch.rand(bsz, seq_len, head_dim, device=device) Y = search_hyperplane(X) assert torch.exp(torch.bmm(X, Y.unsqueeze(-1))).max() == 0 diff --git a/tests/test_per_layer_compression_press.py b/tests/test_per_layer_compression_press.py index 753a6972..6a57d481 100644 --- a/tests/test_per_layer_compression_press.py +++ b/tests/test_per_layer_compression_press.py @@ -13,7 +13,7 @@ def test_per_layer_compression_press(unit_test_model): # noqa: F811 press = PerLayerCompressionPress(compression_ratios=[0.1, 1], press=KnormPress()) with press(unit_test_model): - input_ids = torch.randint(0, 3_000, (5, 256)) + input_ids = torch.randint(0, 3_000, (5, 256), device=unit_test_model.device) past_key_values = unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values assert past_key_values.layers[0].keys.shape == torch.Size([5, 2, 230, 6]) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 030d83cb..c3e6f0b4 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -43,33 +43,34 @@ def test_pipeline_with_cache(kv_press_unit_test_pipeline): # noqa: F811 assert isinstance(answers[0], str) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") -@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") -@pytest.mark.parametrize("compression_ratio", [0.0, 0.2]) -def test_pipeline_fa2(compression_ratio, kv_press_llama3_2_flash_attn_pipeline): # noqa: F811 - context = "This is a test article. It was written on 2022-01-01." - questions = ["Repeat the last sentence"] - press = ExpectedAttentionPress(compression_ratio=compression_ratio) - cache = DynamicCache() - answers = kv_press_llama3_2_flash_attn_pipeline( - context, questions=questions, press=press, cache=cache, max_new_tokens=6 - )["answers"] - - assert len(answers) == 1 - assert isinstance(answers[0], str) - - kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("sdpa") - press = ExpectedAttentionPress(compression_ratio=compression_ratio) - cache = DynamicCache() - answers_sdpa = kv_press_llama3_2_flash_attn_pipeline( - context, questions=questions, press=press, cache=cache, max_new_tokens=6 - )["answers"] - kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("flash_attention_2") - - assert ( - answers_sdpa[0] == answers[0] - ), f"Answers from SDPA and Flash Attention 2 should be the same. \n{answers_sdpa[0]}\n{answers[0]}" - assert "This is a test" in answers[0], f"The answer should contain the context sentence, but got {answers[0]}." +class TestPipelineFA2: + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") + @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") + @pytest.mark.parametrize("compression_ratio", [0.0, 0.2]) + def test_pipeline_fa2(self, kv_press_llama3_2_flash_attn_pipeline, compression_ratio): # noqa: F811 + context = "This is a test article. It was written on 2022-01-01." + questions = ["Repeat the last sentence"] + press = ExpectedAttentionPress(compression_ratio=compression_ratio) + cache = DynamicCache() + answers = kv_press_llama3_2_flash_attn_pipeline( + context, questions=questions, press=press, cache=cache, max_new_tokens=6 + )["answers"] + + assert len(answers) == 1 + assert isinstance(answers[0], str) + + kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("sdpa") + press = ExpectedAttentionPress(compression_ratio=compression_ratio) + cache = DynamicCache() + answers_sdpa = kv_press_llama3_2_flash_attn_pipeline( + context, questions=questions, press=press, cache=cache, max_new_tokens=6 + )["answers"] + kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("flash_attention_2") + + assert ( + answers_sdpa[0] == answers[0] + ), f"Answers from SDPA and Flash Attention 2 should be the same. \n{answers_sdpa[0]}\n{answers[0]}" + assert "This is a test" in answers[0], f"The answer should contain the context sentence, but got {answers[0]}." @pytest.mark.parametrize("question", ["When was this article written?", ""]) @@ -92,7 +93,6 @@ def test_pipeline_no_press_works(kv_press_unit_test_pipeline, caplog): # noqa: kv_press_unit_test_pipeline(context, question=question) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") def test_pipeline_answer_is_correct(danube_500m_model, caplog): # noqa: F811 with caplog.at_level(logging.DEBUG): answers = generate_answer(danube_500m_model) @@ -142,7 +142,7 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 model = unit_test_model questions = ["When was this article written?"] tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + device = model.device compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device) input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"].to(device) @@ -162,14 +162,12 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 def generate_answer(model): - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - model.to(device) + device = model.device context = "This is a test article. It was written on 2022-01-01." questions = ["When was this article written?", "When was this article written?"] press = ExpectedAttentionPress(compression_ratio=0.4) tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) - answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)( + answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device)( context, questions=questions, press=press )["answers"] - model.to(torch.device("cpu")) return answers diff --git a/tests/test_press_call.py b/tests/test_press_call.py index d6307a8b..4a0b4100 100644 --- a/tests/test_press_call.py +++ b/tests/test_press_call.py @@ -23,7 +23,7 @@ def test_context_manager_applies_compression(unit_test_model): # noqa: F811 press = KnormPress(compression_ratio=0.2) with press(unit_test_model): - input_ids = unit_test_model.dummy_inputs["input_ids"] + input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device) past_key_values = unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values seq_len = input_ids.shape[-1] @@ -32,6 +32,7 @@ def test_context_manager_applies_compression(unit_test_model): # noqa: F811 assert key.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length() assert values.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length() + input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device) past_key_values = unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values for key, values in past_key_values: