From d8884e6f6a294fc8f1a325665d86a07603d43864 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:54:26 +0000 Subject: [PATCH 01/17] Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability --- bitblas/ops/impl/base.py | 16 +++ bitblas/ops/impl/batch_matmul_impl.py | 166 ++++++++++++++++---------- 2 files changed, 119 insertions(+), 63 deletions(-) create mode 100644 bitblas/ops/impl/base.py diff --git a/bitblas/ops/impl/base.py b/bitblas/ops/impl/base.py new file mode 100644 index 000000000..6d510f7da --- /dev/null +++ b/bitblas/ops/impl/base.py @@ -0,0 +1,16 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from abc import ABC, abstractmethod + +# TODO: Refactor all the tir script implementations to use this base class +# Abstract base class for TIR script emitters +class TIRScriptEmitter(ABC): + @abstractmethod + def emit(self): + raise NotImplementedError + +# Abstract base class for TIR script selectors +class TIRScriptSelector(ABC): + @abstractmethod + def select(self): + raise NotImplementedError diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py index 09b536afa..75449ea4b 100644 --- a/bitblas/ops/impl/batch_matmul_impl.py +++ b/bitblas/ops/impl/batch_matmul_impl.py @@ -4,63 +4,117 @@ from bitblas import tvm from tvm import te from bitblas.ops.operator import TransformKind +from .base import TIRScriptEmitter, TIRScriptSelector +from bitblas import tvm +from tvm import te +from bitblas.ops.operator import TransformKind +class BatchMatMulEmitter(TIRScriptEmitter): + def __init__( + self, + batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + ): + self.batch = batch + self.M = self._validate_dimension(M, "M") + self.N = self._validate_dimension(N, "N") + self.K = self._validate_dimension(K, "K") + self.in_dtype = in_dtype + self.out_dtype = out_dtype + self.accum_dtype = accum_dtype + self.with_bias = with_bias + self.layout = layout + self._validate_layout() + + @staticmethod + def _validate_dimension(dim, name): + if not isinstance(dim, int): + return tvm.te.var(name.lower()) + return dim -def matmul_nt( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, -): - if not isinstance(M, int): - M = tvm.te.var("m") - A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) - B = te.placeholder((Batch, N, K), name="B", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + def _validate_layout(self): + if self.layout not in ["nn", "nt"]: + raise ValueError(f"Unsupported layout: {self.layout}") + if self.layout == "nn": + raise ValueError("Currently only support layout=nt") - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (Batch, M, N), - lambda b, i, j: te.sum( - A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k), - name="C", - ) - last_output = C - if accum_dtype != out_dtype: - D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") - last_output = D + def _create_placeholders(self): + A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype) + B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype) + Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None + return A, B, Bias - if with_bias: - E = te.compute((Batch, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E") - last_output = E + def _compute_matmul(self, A, B): + k = te.reduce_axis((0, self.K), name="k") + C = te.compute( + (self.batch, self.M, self.N), + lambda b, i, j: te.sum( + A[b, i, k].astype(self.accum_dtype) * B[b, j, k].astype(self.accum_dtype), axis=k), + name="C", + ) + return C - args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + def _apply_bias(self, C, Bias): + if self.with_bias: + return te.compute((self.batch, self.M, self.N), lambda b, i, j: C[b, i, j] + Bias[j], name="E") + return C - func = te.create_prim_func(args) + def _convert_dtype(self, tensor): + if self.accum_dtype != self.out_dtype: + return te.compute((self.batch, self.M, self.N), lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), name="D") + return tensor - return tvm.IRModule.from_expr(func) + def emit(self): + A, B, Bias = self._create_placeholders() + C = self._compute_matmul(A, B) + last_output = self._convert_dtype(C) + if self.with_bias: + last_output = self._apply_bias(last_output, Bias) + args = [A, B, Bias, last_output] if self.with_bias else [A, B, last_output] + func = te.create_prim_func(args) + return tvm.IRModule.from_expr(func) -def matmul( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - layout="nt", -): - if layout == "nn": - raise ValueError("Currently only support layout=nt") - return matmul_nt(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) +class BatchMatMulSelector(TIRScriptSelector): + def __init__(self, propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform): + self.propagate_a = propagate_a + self.propagate_b = propagate_b + + def select( + self, + batch=1, + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + ): + if layout == "nn": + if self.propagate_a or self.propagate_b: + raise ValueError("Currently only support propagate_a=False and propagate_b=False for layout=nn") + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + elif layout == "nt": + if self.propagate_a and self.propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif self.propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif self.propagate_b: + raise ValueError("Currently only support propagate_b=False for layout=nt") + else: + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + else: + raise ValueError(f"Unsupported layout: {layout}") def select_implementation( Batch=1, @@ -75,19 +129,5 @@ def select_implementation( propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform, ): - if layout == "nn": - if propagate_a or propagate_b: - raise ValueError( - "Currently only support propagate_a=False and propagate_b=False for layout=nn") - return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - elif layout == "nt": - if propagate_a and propagate_b: - raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") - elif propagate_a: - raise ValueError("Currently only support propagate_a=False for layout=nt") - elif propagate_b: - raise ValueError("Currently only support propagate_b=False for layout=nt") - else: - return matmul(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - else: - raise ValueError(f"Unsupported layout: {layout}") + selector = BatchMatMulSelector(propagate_a, propagate_b) + return selector.select(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) From fc84173f22d2f4867a8e6413117b5cd8e830ab27 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:57:43 +0000 Subject: [PATCH 02/17] Refactor import statements for improved readability and maintainability --- bitblas/ops/impl/__init__.py | 2 +- bitblas/ops/impl/base.py | 4 ++++ bitblas/ops/impl/batch_matmul_impl.py | 33 ++++++++++++++++++--------- 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/bitblas/ops/impl/__init__.py b/bitblas/ops/impl/__init__.py index a254dc7fb..8a9bbd2a5 100644 --- a/bitblas/ops/impl/__init__.py +++ b/bitblas/ops/impl/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .lop3_permutate_impl import tir_interleave_weight +from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 diff --git a/bitblas/ops/impl/base.py b/bitblas/ops/impl/base.py index 6d510f7da..4a67987be 100644 --- a/bitblas/ops/impl/base.py +++ b/bitblas/ops/impl/base.py @@ -2,15 +2,19 @@ # Licensed under the MIT License. from abc import ABC, abstractmethod + # TODO: Refactor all the tir script implementations to use this base class # Abstract base class for TIR script emitters class TIRScriptEmitter(ABC): + @abstractmethod def emit(self): raise NotImplementedError + # Abstract base class for TIR script selectors class TIRScriptSelector(ABC): + @abstractmethod def select(self): raise NotImplementedError diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py index 75449ea4b..3904f36e6 100644 --- a/bitblas/ops/impl/batch_matmul_impl.py +++ b/bitblas/ops/impl/batch_matmul_impl.py @@ -5,11 +5,10 @@ from tvm import te from bitblas.ops.operator import TransformKind from .base import TIRScriptEmitter, TIRScriptSelector -from bitblas import tvm -from tvm import te -from bitblas.ops.operator import TransformKind + class BatchMatMulEmitter(TIRScriptEmitter): + def __init__( self, batch, @@ -32,7 +31,7 @@ def __init__( self.with_bias = with_bias self.layout = layout self._validate_layout() - + @staticmethod def _validate_dimension(dim, name): if not isinstance(dim, int): @@ -48,7 +47,8 @@ def _validate_layout(self): def _create_placeholders(self): A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype) B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None + Bias = te.placeholder( + (self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None return A, B, Bias def _compute_matmul(self, A, B): @@ -63,12 +63,16 @@ def _compute_matmul(self, A, B): def _apply_bias(self, C, Bias): if self.with_bias: - return te.compute((self.batch, self.M, self.N), lambda b, i, j: C[b, i, j] + Bias[j], name="E") + return te.compute((self.batch, self.M, self.N), + lambda b, i, j: C[b, i, j] + Bias[j], + name="E") return C def _convert_dtype(self, tensor): if self.accum_dtype != self.out_dtype: - return te.compute((self.batch, self.M, self.N), lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), name="D") + return te.compute((self.batch, self.M, self.N), + lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), + name="D") return tensor def emit(self): @@ -84,7 +88,10 @@ def emit(self): class BatchMatMulSelector(TIRScriptSelector): - def __init__(self, propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform): + + def __init__(self, + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform): self.propagate_a = propagate_a self.propagate_b = propagate_b @@ -102,8 +109,10 @@ def select( ): if layout == "nn": if self.propagate_a or self.propagate_b: - raise ValueError("Currently only support propagate_a=False and propagate_b=False for layout=nn") - return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn") + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, + layout).emit() elif layout == "nt": if self.propagate_a and self.propagate_b: raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") @@ -112,10 +121,12 @@ def select( elif self.propagate_b: raise ValueError("Currently only support propagate_b=False for layout=nt") else: - return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, + with_bias, layout).emit() else: raise ValueError(f"Unsupported layout: {layout}") + def select_implementation( Batch=1, M=None, From 02f64de6cf2d338c092dcf29ec55b69804fda892 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:58:06 +0000 Subject: [PATCH 03/17] Refactor import statements for improved readability and maintainability --- bitblas/ops/impl/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/ops/impl/__init__.py b/bitblas/ops/impl/__init__.py index 8a9bbd2a5..67e49b2ae 100644 --- a/bitblas/ops/impl/__init__.py +++ b/bitblas/ops/impl/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 +from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 From 397eee6141599e84b509594bb99a0531e409c266 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 16:25:47 +0000 Subject: [PATCH 04/17] disable failure email for ci --- .github/workflows/ci.yml | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ceb69fcc7..1fbdf19dd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,4 +64,13 @@ jobs: run: | source bitblas_ci/bin/activate cd testing/python - python -m pytest \ No newline at end of file + python -m pytest + + # Control notifications + notify: + runs-on: self-hosted + needs: [format-check, build-test] + if: failure() + steps: + - name: Notification + run: echo "Jobs failed, but no email will be sent." From 20f6ad1e7ca4e6e1ca9e13ad7c1bbc8c430a8e51 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 03:23:50 +0000 Subject: [PATCH 05/17] remove email notifications. --- .github/workflows/ci.yml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1fbdf19dd..511b95833 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,12 +65,3 @@ jobs: source bitblas_ci/bin/activate cd testing/python python -m pytest - - # Control notifications - notify: - runs-on: self-hosted - needs: [format-check, build-test] - if: failure() - steps: - - name: Notification - run: echo "Jobs failed, but no email will be sent." From b93c39431c803e22b12f71b555939785da36b96a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 03:25:05 +0000 Subject: [PATCH 06/17] move relax pass from testing to mlc_llm --- .../mlc_llm}/test_weight_only_transform.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {testing/python/transform => integration/mlc_llm}/test_weight_only_transform.py (100%) diff --git a/testing/python/transform/test_weight_only_transform.py b/integration/mlc_llm/test_weight_only_transform.py similarity index 100% rename from testing/python/transform/test_weight_only_transform.py rename to integration/mlc_llm/test_weight_only_transform.py From 257693a7c3cb3083aac144182f58d38bfe3bcfdd Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:51:01 +0000 Subject: [PATCH 07/17] Refactor scripts with se check_eual_ref_scripts_with_emitter function --- bitblas/ops/impl/matmul_dequantize_impl.py | 224 ++++++++++++++---- .../operators/test_tir_script_emitter.py | 52 +++- 2 files changed, 216 insertions(+), 60 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index 1ed6b3404..e69e8fcfb 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -15,8 +15,10 @@ _tir_packed_to_unsigned_convert_with_zeros, ) + # TODO: The following code should be refactored. class MatMulNTDequantizeEmitter: + def __init__( self, M, @@ -52,8 +54,8 @@ def __init__( self.fast_decoding = fast_decoding self.with_bias = with_bias self.zeros_mode = zeros_mode - self.propagate_a = propagate_a - self.propagate_b = propagate_b + self.propagate_a = self._legalize_transform_kind(propagate_a) + self.propagate_b = self._legalize_transform_kind(propagate_b) self._validate_bit() self._validate_layout() @@ -69,62 +71,169 @@ def _validate_bit(self): raise ValueError(f"Unsupported bit: {self.bit}") def _validate_layout(self): - if self.layout not in ["nt"]: - raise ValueError(f"Unsupported layout: {self.layout}") + # TODO: extend the dequantize operators into General Layout + pass + + def _legalize_group_size(self): + if self.group_size == -1: + self.group_size = self.K + + def _legalize_transform_kind(self, propagate): + if propagate is None: + return TransformKind.NonTransform + if isinstance(propagate, bool): + return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) + elif isinstance(propagate, int): + return TransformKind(propagate) def _create_placeholders(self): - storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) - n_float_per_elem = storage_nbit // self.bit - - A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype) - LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype) - Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype) - Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype) - QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit), + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + in_dtype = self.in_dtype + bit = self.bit + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype) + B = te.placeholder((self.N, self.K // storage_nbit * bit), + name="B", + dtype=storage_dtype) + if self.propagate_a: + A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype) + if self.propagate_b: + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // target_dtype.bits) + qr = r * bit // storage_nbit + B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=in_dtype) + QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * bit), name="QZeros", dtype=self.storage_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) - return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem + Bias = te.placeholder((self.N,), name="Bias", dtype=in_dtype) + return A, B, LUT, Scale, Zeros, QZeros, Bias + + def _propagate_input(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="A"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=False, dtype=in_dtype, matrix_name=matrix_name) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.M, self.K), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _propagage_weight(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="B"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + bit = self.bit + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=True, dtype=in_dtype, matrix_name=matrix_name) + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inversed_index_map.initial_indices + scaling_final_indices = inversed_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inversed_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + qr = r * bit // storage_nbit - def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem): - w = None + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.N, self.K // storage_nbit * bit), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _decode_func(self, B, LUT, Scale, Zeros, QZeros): + bit = self.bit + in_dtype = self.in_dtype + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + + # TODO: Move the decode function into a more general place def decode(n, k): + w = None if self.with_zeros and self.zeros_mode == "quantized": - qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, + qzeros_dequantize = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, QZeros[k, n // n_float_per_elem], n % n_float_per_elem, dtype=self.storage_dtype, ) - w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)( - self.bit, + w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, qzeros_dequantize, - dtype=self.in_dtype, + dtype=in_dtype, ) elif self.source_format == "uint": - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + if bit == 8: + w = B[n, k].astype(in_dtype) + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "int": - if self.bit == 1: - w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + if bit == 1: + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + if bit == 8: + w = B[n, k].astype(in_dtype) + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp": w = _tir_u32_to_f4_to_f16( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype) + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) elif self.source_format == "nf": - index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, + index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype="int32", @@ -132,7 +241,9 @@ def decode(n, k): w = LUT[index] else: raise ValueError(f"Unsupported source_format: {self.source_format}") - + + assert w is not None, "w is None" + group_size = self.group_size zeros_mode = self.zeros_mode @@ -167,7 +278,9 @@ def _compute_matmul(self, A, B_decode): def _convert_dtype(self, tensor): if self.accum_dtype != self.out_dtype: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D") + return te.compute((self.M, self.N), + lambda i, j: tensor[i, j].astype(self.out_dtype), + name="D") return tensor def _apply_bias(self, tensor, Bias): @@ -176,9 +289,12 @@ def _apply_bias(self, tensor, Bias): return tensor def emit(self): - A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders() - B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem) - C = self._compute_matmul(A, B_decode) + A, B, LUT, Scale, Zeros, QZeros, Bias = self._create_placeholders() + A_reindex = self._propagate_input(A, self.propagate_a, "A") + B_reindex = self._propagage_weight(B, self.propagate_b, "B") + + B_decode = self._decode_func(B_reindex, LUT, Scale, Zeros, QZeros) + C = self._compute_matmul(A_reindex, B_decode) D = self._convert_dtype(C) last_output = self._apply_bias(D, Bias) @@ -212,8 +328,13 @@ def emit(self): } }, ) + if self.propagate_a: + func = func.with_attr("input_transform_kind", self.propagate_a.value) + if self.propagate_b: + func = func.with_attr("weight_transform_kind", self.propagate_b.value) return tvm.IRModule.from_expr(func) + def matmul_nt_dequantize_b( M, N, @@ -335,9 +456,12 @@ def decode_func(n, k): A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = D if source_format == "nf": args.append(LUT) if with_scaling: @@ -517,9 +641,11 @@ def decode_func(n, k): A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = D if source_format == "nf": args.append(LUT) if with_scaling: @@ -715,9 +841,11 @@ def decode_func(n, k): ), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = D if source_format == "nf": args.append(LUT) if with_scaling: diff --git a/testing/python/operators/test_tir_script_emitter.py b/testing/python/operators/test_tir_script_emitter.py index cec56b473..fcfa7d9af 100644 --- a/testing/python/operators/test_tir_script_emitter.py +++ b/testing/python/operators/test_tir_script_emitter.py @@ -1,18 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from bitblas.ops.impl.matmul_dequantize_impl import ( - MatMulNTDequantizeEmitter, - matmul_nt_dequantize_b, - matmul_nt_dequantize_b_propagate_b, - matmul_nt_dequantize_b_propagate_a_propagate_b, -) from bitblas import tvm import logging from bitblas import set_log_level set_log_level(logging.DEBUG) -def compare_tir_scripts_and_emitter( + +def check_eual_ref_scripts_with_emitter( M, N, K, @@ -28,8 +23,26 @@ def compare_tir_scripts_and_emitter( fast_decoding, with_bias, zeros_mode, + propagate_a, + propagate_b, ): - tir_script_func = matmul_nt_dequantize_b( + from bitblas.ops.impl.matmul_dequantize_impl import ( + MatMulNTDequantizeEmitter, + matmul_nt_dequantize_b, + matmul_nt_dequantize_b_propagate_b, + matmul_nt_dequantize_b_propagate_a_propagate_b, + ) + func = None + if propagate_a and propagate_b: + func = matmul_nt_dequantize_b_propagate_a_propagate_b + elif propagate_b: + func = matmul_nt_dequantize_b_propagate_b + else: + func = matmul_nt_dequantize_b + + assert func is not None, "No function found for the given configuration" + + ref_func = func( M, N, K, @@ -46,8 +59,8 @@ def compare_tir_scripts_and_emitter( with_bias, zeros_mode, ) - - emitter_func = MatMulNTDequantizeEmitter( + + emit_func = MatMulNTDequantizeEmitter( M, N, K, @@ -63,6 +76,21 @@ def compare_tir_scripts_and_emitter( fast_decoding, with_bias, zeros_mode, + propagate_a=propagate_a, + propagate_b=propagate_b, ).emit() - - tvm.ir.assert_structural_equal(tir_script_func, emitter_func) + + tvm.ir.assert_structural_equal(ref_func, emit_func) + + +def test_check_eual_ref_scripts_with_emitter(): + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) + check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", True, True) + +if __name__ == "__main__": + test_check_eual_ref_scripts_with_emitter() From 9bb7f49a968d4c71dbbc12121b4b7cb8258b2136 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:51:15 +0000 Subject: [PATCH 08/17] Lint Fix --- bitblas/ops/impl/matmul_dequantize_impl.py | 13 +++++---- .../operators/test_tir_script_emitter.py | 29 ++++++++++++++----- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index e69e8fcfb..7b91764ca 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -73,7 +73,7 @@ def _validate_bit(self): def _validate_layout(self): # TODO: extend the dequantize operators into General Layout pass - + def _legalize_group_size(self): if self.group_size == -1: self.group_size = self.K @@ -96,18 +96,19 @@ def _create_placeholders(self): l, r = 16, 32 # noqa: E741 A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * bit), - name="B", - dtype=storage_dtype) + B = te.placeholder((self.N, self.K // storage_nbit * bit), name="B", dtype=storage_dtype) if self.propagate_a: A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype) if self.propagate_b: target_dtype = DataType(in_dtype) scaling_factor = 1 if bit > 0 and bit < target_dtype.bits: - scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // target_dtype.bits) + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) qr = r * bit // storage_nbit - B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), + name="B", + dtype=storage_dtype) LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype) diff --git a/testing/python/operators/test_tir_script_emitter.py b/testing/python/operators/test_tir_script_emitter.py index fcfa7d9af..b2c7a8d4f 100644 --- a/testing/python/operators/test_tir_script_emitter.py +++ b/testing/python/operators/test_tir_script_emitter.py @@ -84,13 +84,28 @@ def check_eual_ref_scripts_with_emitter( def test_check_eual_ref_scripts_with_emitter(): - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) - check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", True, True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "nf", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, + "int8", "nf", True, False, -1, False, False, "original", + False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + True) + check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, + "int8", "uint", True, False, -1, False, False, "original", + True, True) + if __name__ == "__main__": test_check_eual_ref_scripts_with_emitter() From 93eb5a5fe4e3eb6242675dd5706358c4121f1672 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:53:50 +0000 Subject: [PATCH 09/17] Refactor scripts with se check_eual_ref_scripts_with_emitter function --- bitblas/ops/impl/matmul_dequantize_impl.py | 198 --------------------- 1 file changed, 198 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index 1ef14100d..7b91764ca 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -15,204 +15,6 @@ _tir_packed_to_unsigned_convert_with_zeros, ) -# TODO: The following code should be refactored. -class MatMulNTDequantizeEmitter: - def __init__( - self, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - zeros_mode="original", - propagate_a: TransformKind = TransformKind.NonTransform, - propagate_b: TransformKind = TransformKind.NonTransform, - ): - self.M = self._validate_dimension(M, "M") - self.N = N - self.K = K - self.in_dtype = in_dtype - self.out_dtype = out_dtype - self.accum_dtype = accum_dtype - self.bit = bit - self.storage_dtype = storage_dtype - self.source_format = source_format - self.with_scaling = with_scaling - self.with_zeros = with_zeros - self.group_size = group_size if group_size != -1 else K - self.fast_decoding = fast_decoding - self.with_bias = with_bias - self.zeros_mode = zeros_mode - self.propagate_a = propagate_a - self.propagate_b = propagate_b - - self._validate_bit() - self._validate_layout() - - @staticmethod - def _validate_dimension(dim, name): - if not isinstance(dim, int): - return tvm.te.var(name.lower()) - return dim - - def _validate_bit(self): - if self.bit not in [1, 2, 4, 8]: - raise ValueError(f"Unsupported bit: {self.bit}") - - def _validate_layout(self): - if self.layout not in ["nt"]: - raise ValueError(f"Unsupported layout: {self.layout}") - - def _create_placeholders(self): - storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) - n_float_per_elem = storage_nbit // self.bit - - A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype) - LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype) - Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype) - Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype) - QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit), - name="QZeros", - dtype=self.storage_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) - return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem - - def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem): - w = None - def decode(n, k): - if self.with_zeros and self.zeros_mode == "quantized": - qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, - QZeros[k, n // n_float_per_elem], - n % n_float_per_elem, - dtype=self.storage_dtype, - ) - w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)( - self.bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - qzeros_dequantize, - dtype=self.in_dtype, - ) - elif self.source_format == "uint": - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif self.source_format == "int": - if self.bit == 1: - w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif self.source_format == "fp": - w = _tir_u32_to_f4_to_f16( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif self.source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype) - elif self.source_format == "nf": - index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", - ) - w = LUT[index] - else: - raise ValueError(f"Unsupported source_format: {self.source_format}") - - group_size = self.group_size - zeros_mode = self.zeros_mode - - if not self.with_scaling: - return w - - if not self.with_zeros: - return w * Scale[n, k // group_size] - - if zeros_mode == "original": - w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] - elif zeros_mode == "rescale": - w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] - elif zeros_mode == "quantized": - w = w * Scale[n, k // group_size] - else: - raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) - - return w - - return te.compute((self.N, self.K), decode, name="B_decode") - - def _compute_matmul(self, A, B_decode): - k = te.reduce_axis((0, self.K), name="k") - C = te.compute( - (self.M, self.N), - lambda i, j: te.sum( - A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k), - name="C", - ) - return C - - def _convert_dtype(self, tensor): - if self.accum_dtype != self.out_dtype: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D") - return tensor - - def _apply_bias(self, tensor, Bias): - if self.with_bias: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E") - return tensor - - def emit(self): - A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders() - B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem) - C = self._compute_matmul(A, B_decode) - D = self._convert_dtype(C) - last_output = self._apply_bias(D, Bias) - - args = [A, B] - if self.source_format == "nf": - args.append(LUT) - if self.with_scaling: - args.append(Scale) - if self.with_zeros: - args.append(QZeros if self.zeros_mode == "quantized" else Zeros) - if self.with_bias: - args.append(Bias) - args.append(last_output) - - func = te.create_prim_func(args).with_attr( - "dequantize_info", - { - "B_decode": { - "decode_block": "B_decode", - "fast_decoding": self.fast_decoding, - "source_format": { - "bits": self.bit, - "format": self.source_format, - }, - "storage_dtype": self.storage_dtype, - "target_format": self.in_dtype, - "with_zeros": self.with_zeros, - "zeros_mode": self.zeros_mode, - "with_scaling": self.with_scaling, - "group_size": self.group_size, - } - }, - ) - return tvm.IRModule.from_expr(func) # TODO: The following code should be refactored. class MatMulNTDequantizeEmitter: From aa66a9080d41330ba63f38b76c539c6be0362906 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 08:03:46 +0000 Subject: [PATCH 10/17] bug fix in test --- bitblas/ops/impl/matmul_dequantize_impl.py | 9 ++-- testing/python/module/test_bitblas_linear.py | 41 ++++++++----------- .../operators/test_general_matmul_ops.py | 2 +- 3 files changed, 22 insertions(+), 30 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index 7b91764ca..55d672097 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -473,8 +473,7 @@ def decode_func(n, k): else: args.append(Zeros) if with_bias: - E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") - last_output = E + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") args.append(Bias) args.append(last_output) @@ -654,8 +653,7 @@ def decode_func(n, k): if with_zeros: args.append(Zeros) if with_bias: - E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") - last_output = E + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") args.append(Bias) args.append(last_output) @@ -854,8 +852,7 @@ def decode_func(n, k): if with_zeros: args.append(Zeros) if with_bias: - E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") - last_output = E + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") args.append(Bias) args.append(last_output) diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index eeaf90475..eee08c93c 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -11,16 +11,7 @@ torch.manual_seed(0) bitblas.set_log_level("DEBUG") -@pytest.mark.parametrize( - "m, in_features, out_features, bias", - [ - (1, 1024, 1024, False), - (1, 1024, 1024, True), - (1024, 1024, 1024, True), - ([1, 1024], 1024, 1024, True), - ], -) -def test_correctness_consistent(m, in_features, out_features, bias): +def correctness_consistent(m, in_features, out_features, bias): linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda()) linear_bitblas = BitBLASLinear( in_features, @@ -48,19 +39,13 @@ def test_correctness_consistent(m, in_features, out_features, bias): torch.testing.assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2) -@pytest.mark.parametrize( - "m, in_features, out_features, bias, W_dtype, group_size, with_scaling, with_zeros, zeros_mode", - [ - (1, 1024, 1024, False, "uint4", -1, False, False, None), - (1, 1024, 1024, False, "uint4", -1, False, False, None), - (1024, 1024, 1024, True, "uint4", -1, False, False, None), - (1, 1024, 1024, True, "uint2", -1, True, False, None), - (1, 1024, 1024, True, "uint2", 128, True, True, "original"), - (1024, 1024, 1024, True, "uint2", 128, True, True, "original"), - (1, 1024, 1024, True, "uint2", 128, True, True, "rescale"), - ], -) -def test_correctness_weight_only_dequantize( +def test_correctness_consistent(): + correctness_consistent(1, 1024, 1024, False) + correctness_consistent(1, 1024, 1024, True) + correctness_consistent(1024, 1024, 1024, True) + correctness_consistent([1, 1024], 1024, 1024, True) + +def correctness_weight_only_dequantize( m, in_features, out_features, @@ -169,6 +154,16 @@ def test_correctness_weight_only_dequantize( torch.testing.assert_close(output_bitblas, ref_result, rtol=1e0, atol=1e0) +def test_correctness_weight_only_dequantize(): + correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) + correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) + correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None) + correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", -1, True, False, None) + correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original") + correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original") + correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale") + + def profile(model, input_data): model = model.cuda() model.eval() diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py index 05e0a45f4..62808e2a7 100644 --- a/testing/python/operators/test_general_matmul_ops.py +++ b/testing/python/operators/test_general_matmul_ops.py @@ -195,7 +195,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo if with_bias: permuted_inputs.append(bias) permuted_inputs.append(inputs[2]) - matmul(*permuted_inputs[:2], output=permuted_inputs[-1]) + matmul(*permuted_inputs[:-1], output=permuted_inputs[-1]) if zeros_mode == "rescale": torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0) else: From 79b08e415ffe79d7d4320e815cc2f5e603775e57 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 08:21:08 +0000 Subject: [PATCH 11/17] lint fix. --- testing/python/module/test_bitblas_linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index eee08c93c..f329a146e 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -6,11 +6,11 @@ import time import numpy as np import torch.nn as nn -import pytest torch.manual_seed(0) bitblas.set_log_level("DEBUG") + def correctness_consistent(m, in_features, out_features, bias): linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda()) linear_bitblas = BitBLASLinear( @@ -45,6 +45,7 @@ def test_correctness_consistent(): correctness_consistent(1024, 1024, 1024, True) correctness_consistent([1, 1024], 1024, 1024, True) + def correctness_weight_only_dequantize( m, in_features, From 86fd0361bb74e21a87a26159022282ca25a4b282 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 7 Jul 2024 14:03:33 +0000 Subject: [PATCH 12/17] test cuda i4 kernel --- testing/cpp/CMakeLists.txt | 1 + .../cpp/efficient_i4_cuda_impl/CMakeLists.txt | 20 + .../efficient_i4_cuda_impl/efficient_i4.cu | 391 +++++++++ .../cpp/efficient_i4_cuda_impl/i4matmul.hpp | 822 ++++++++++++++++++ .../param_permutate.cpp | 89 ++ 5 files changed, 1323 insertions(+) create mode 100644 testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt create mode 100644 testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu create mode 100644 testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp create mode 100644 testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp diff --git a/testing/cpp/CMakeLists.txt b/testing/cpp/CMakeLists.txt index cf8eb0d3a..b92fa8da7 100644 --- a/testing/cpp/CMakeLists.txt +++ b/testing/cpp/CMakeLists.txt @@ -12,4 +12,5 @@ find_package(GTest REQUIRED) include_directories(${GTEST_INCLUDE_DIRS}) +add_subdirectory(efficient_i4_cuda_impl) add_subdirectory(lop3_type_conversion) diff --git a/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt b/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt new file mode 100644 index 000000000..36ffdf548 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +function (ADD_CUDA_TEST_EXECUTABLE name) + add_executable(${name} ${name}.cu) + set_target_properties(${name} PROPERTIES CUDA_ARCHITECTURES 80) + # add flags + target_compile_options(${name} PRIVATE --expt-relaxed-constexpr) + set_target_properties(${name} PROPERTIES + CUDA_SEPARABLE_COMPILATION ON) + target_link_libraries(${name} gtest gtest_main) +endfunction(ADD_CUDA_TEST_EXECUTABLE) + +ADD_CUDA_TEST_EXECUTABLE(efficient_i4) + +function (ADD_CPP_TEST_EXECUTABLE name) + add_executable(${name} ${name}.cpp) + target_link_libraries(${name} gtest gtest_main pthread) +endfunction(ADD_CPP_TEST_EXECUTABLE) + +ADD_CPP_TEST_EXECUTABLE(param_permutate) diff --git a/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu new file mode 100644 index 000000000..257f49a31 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu @@ -0,0 +1,391 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include +#include +#include +#include "i4matmul.hpp" + +#define cudaCheckLastError(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } +inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) +{ + if (code != cudaSuccess) + { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) + exit(code); + } +} + +void general_compress(const int8_t *lowbit, int8_t *compressed, const int nbit, const int N, bool isSigned = false) +{ + int zero_point = isSigned ? ((1 << (nbit - 1)) - 1) : 0; + const int nbit_per_byte = 8 / nbit; + + for (int i = 0; i < N / nbit_per_byte; i++) + { + compressed[i] = 0; + for (int j = 0; j < nbit_per_byte; j++) + { + compressed[i] |= ((lowbit[nbit_per_byte * i + j] + zero_point) << (nbit * j)); + } + } +} + + +// Helper function to interleave the perm array +std::vector interleave_perms(const std::vector& perm) { + std::vector interleaved_perm; + std::array interleave = {0, 2, 4, 6, 1, 3, 5, 7}; + + int num_rows = perm.size() / 8; + for (int i = 0; i < num_rows; ++i) { + std::array row; + std::copy(perm.begin() + i * 8, perm.begin() + (i + 1) * 8, row.begin()); + for (int j : interleave) { + interleaved_perm.push_back(row[j]); + } + } + + return interleaved_perm; +} + + +std::tuple, std::vector, std::vector> get_perms() { + std::vector perm; + + for (int i = 0; i < 32; ++i) { + std::vector perm1; + int col = i / 4; + for (int block : {0, 1}) { + for (int row : { + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1 + }) { + perm1.push_back(16 * row + col + 8 * block); + } + } + for (int j = 0; j < 4; ++j) { + for (int p : perm1) { + perm.push_back(p + 256 * j); + } + } + } + + // Interleave the perm array + perm = interleave_perms(perm); + + std::vector scale_perm; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + scale_perm.push_back(i + 8 * j); + } + } + + std::vector scale_perm_single; + for (int i = 0; i < 4; ++i) { + for (int j : {0, 1, 8, 9, 16, 17, 24, 25}) { + scale_perm_single.push_back(2 * i + j); + } + } + + return std::make_tuple(perm, scale_perm, scale_perm_single); +} + +void weight_pre_process(const int8_t *lowbit, int8_t *compressed, const int nbit, const int K, const int N) +{ + int8_t* tmp1 = new int8_t[K * N]; + const int maxq = 15; + auto [perm, scale_perm, scale_perm_single] = get_perms(); + const int tile_size = 16; + // transform the lowbit matrix to the compressed matrix + for (int i = 0; i < (K / tile_size); i += 1) + { + for (int j = 0; j < (N / tile_size); j += 1) + { + for (int k = 0; k < tile_size; k++) + { + for (int l = 0; l < tile_size; l++) + { + int idx_target = i * N * tile_size + j * tile_size * tile_size + k * tile_size + l; + int idx_source = (i * tile_size + k) * N + j * tile_size + l; + tmp1[idx_target] = lowbit[idx_source] + (maxq + 1) / 2; + } + } + } + } + // print the first 10 of tmp2 + printf("tmp1\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp1[i]); + } + printf(" ... "); + for (int i = K * N - 10; i < K * N; i++) + { + printf("%d ", tmp1[i]); + } + printf("\n"); + // permute the matrix + int32_t* tmp2 = new int32_t[K * N]; + const int perm_size = perm.size(); + for (int i = 0; i < (N * K / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + perm[j]; + tmp2[idx_target] = (int32_t)tmp1[idx_source]; + } + } + // print the first 10 of tmp2 + printf("tmp2\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp2[i]); + } + printf(" ... "); + for (int i = K * N / (32 / nbit) - 10; i < K * N / (32 / nbit); i++) + { + printf("%d ", tmp2[i]); + } + printf("\n"); + // compress + int32_t* tmp3 = new int32_t[K * N / (32 / nbit)]; + // set zero + for (int i = 0; i < K * N / (32 / nbit); i++) + { + tmp3[i] = 0; + } + for (int i = 0; i < (K / tile_size); i++) + { + for (int j = 0; j < (N * tile_size / 8); j++) + { + for (int k = 0; k < 8; k++) + { + int idx_target = i * N * tile_size / 8 + j; + int idx_source = i * N * tile_size + j * 8 + k; + tmp3[idx_target] |= (tmp2[idx_source] << (nbit * (k % 8))); + } + } + } + // print the first 10 of tmp3 + printf("tmp3\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp3[i]); + } + printf(" ... "); + for (int i = K * N / (32 / nbit) - 10; i < K * N / (32 / nbit); i++) + { + printf("%d ", tmp3[i]); + } + printf("\n"); + // copy tmp3 to compressed + for (int i = 0; i < K * N / (32 / nbit); i++) + { + ((int32_t *)(compressed))[i] = tmp3[i]; + } +} + +void scale_pre_process(const half *scale, half *scale_perm, const int K, const int N, int group_size) +{ + auto [perm, scale_perm_group, scale_perm_single] = get_perms(); + if (group_size == -1) + group_size = K; + if (group_size == K){ + const int perm_size = scale_perm_single.size(); + for (int i = 0; i < (N * K / group_size / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + scale_perm_single[j]; + if (idx_target < 10){ + printf("idx_target = %d, idx_source = %d\n", idx_target, idx_source); + } + scale_perm[idx_target] = scale[idx_source]; + } + } + } + else{ + const int perm_size = scale_perm_group.size(); + for (int i = 0; i < (N * K / group_size / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + scale_perm_group[j]; + scale_perm[idx_target] = scale[idx_source]; + } + } + } + // print the first 10 of tmp2 + printf("scale_perm\n"); + for (int i = 0; i < 10; i++) + { + printf("%f ", (float)scale_perm[i]); + } + printf(" ... "); + for (int i = K * N / group_size - 10; i < K * N / group_size; i++) + { + printf("%f ", (float)scale_perm[i]); + } +} + +TEST(EfficientI4MatmulTest, GEMVTest) +{ + const int prom_m = 1; + const int prom_n = 256; + const int prom_k = 256; + const int bits = 4; + const int group_size = prom_k; + + half* A = new half[prom_m * prom_k]; + int8_t* B = new int8_t[prom_k * prom_n]; + int8_t* qB_interleave = new int8_t[prom_k * prom_n / (8 / bits)]; + half* C = new half[prom_m * prom_n]; + half* s = new half[prom_n * (prom_k / group_size)]; + half* s_perm = new half[prom_n * (prom_k / group_size)]; + + // Initialize A and B + for (int i = 0; i < prom_m * prom_k; i++) + { + A[i] = __float2half(rand() / (float)RAND_MAX); + } + for (int i = 0; i < prom_k * prom_n; i++) + { + B[i] = rand() % 4 - 2; + } + for (int i = 0; i < prom_k * prom_n / group_size; i++) + { + // s[i] = __float2half(0.1); + s[i] = __float2half(rand() / (float)RAND_MAX); + } + + weight_pre_process(B, qB_interleave, bits, prom_k, prom_n); + // print the first 10 elements and last 10 elements of C + for (int i = 0; i < 10; i++) + { + printf("%d ", B[i]); + } + printf(" ... "); + for (int i = prom_k * prom_n - 10; i < prom_k * prom_n; i++) + { + printf("%d ", B[i]); + } + // print interleave of B + for (int i = 0; i < 10; i++) + { + printf("%d ", qB_interleave[i]); + } + printf(" ... "); + for (int i = prom_k * prom_n / (8 / bits) - 10; i < prom_k * prom_n / (8 / bits); i++) + { + printf("%d ", qB_interleave[i]); + } + printf("\n"); + // print last 10 of qb_interleave + for (int i = prom_k * prom_n / (8 / bits) - 10; i < prom_k * prom_n / (8 / bits); i++) + { + printf("%d ", qB_interleave[i]); + } + printf("\n"); + // print last 10 of B + for (int i = prom_k * prom_n - 10; i < prom_k * prom_n; i++) + { + printf("%d ", B[i]); + } + printf("\n"); + // print last 10 of s + for (int i = prom_n * (prom_k / group_size) - 10; i < prom_n * (prom_k / group_size); i++) + { + printf("%f ", __half2float(s[i])); + } + printf("\n"); + scale_pre_process(s, s_perm, prom_k, prom_n, group_size); + // define cuda variables + float* d_workspace = nullptr; + cudaCheckLastError(cudaMalloc((void**)&d_workspace, prom_n * prom_k * 16 * sizeof(float))); + + half* d_A; + int8_t* d_qB; + half* d_C; + half* d_s; + cudaCheckLastError(cudaMalloc((void**)&d_A, prom_m * prom_k * sizeof(half))); + cudaCheckLastError(cudaMalloc((void**)&d_qB, prom_k * prom_n / (8 / bits) * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void**)&d_C, prom_m * prom_n * sizeof(half))); + cudaCheckLastError(cudaMalloc((void**)&d_s, prom_n * (prom_k / group_size) * sizeof(half))); + // copy A and B to device + cudaCheckLastError(cudaMemcpy(d_A, A, prom_m * prom_k * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(d_qB, qB_interleave, prom_n * prom_k / (8 / bits) * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(d_s, s_perm, prom_n * (prom_k / group_size) * sizeof(half), cudaMemcpyHostToDevice)); + + // allocate workspace + // call the kernel + int ret = marlin_cuda(d_A, d_qB, d_C, d_s, prom_m, prom_n, prom_k, d_workspace, group_size == prom_k? -1: group_size); + printf("ret = %d\n", ret); + + // copy C back to host + cudaCheckLastError(cudaMemcpy(C, d_C, prom_m * prom_n * sizeof(half), cudaMemcpyDeviceToHost)); + // print the first 10 elements and last 10 elements of C + for (int i = 0; i < 10; i++) + { + printf("%f ", __half2float(C[i])); + } + printf(" ... "); + for (int i = prom_m * prom_n - 10; i < prom_m * prom_n; i++) + { + printf("%f ", __half2float(C[i])); + } + printf("\n"); + + // ref calculation + float* ref_C = new float[prom_m * prom_n]; + // zero fill + for (int i = 0; i < prom_m * prom_n; i++) + { + ref_C[i] = __float2half(0.0); + } + // + for (int i = 0; i < prom_m; i++) + { + for (int j = 0; j < prom_n; j++) + { + ref_C[i * prom_n + j] = __float2half(0.0); + for (int k = 0; k < prom_k; k++) + { + ref_C[i * prom_n + j] += float(A[i * prom_k + k]) * (float(B[k * prom_n + j]) * float(s[(k / group_size) * prom_n + j])); + } + } + } + for (int i = 0; i < 10; i++) + { + printf("%f ", __half2float(ref_C[i])); + } + printf(" ... "); + for (int i = prom_m * prom_n - 10; i < prom_m * prom_n; i++) + { + printf("%f ", __half2float(ref_C[i])); + } + printf("\n"); + + // check the result + for (int i = 0; i < prom_m * prom_n; i++) + { + EXPECT_NEAR(__half2float(C[i]), __half2float(ref_C[i]), 1e-1); + } + + // free memory + delete[] A; + delete[] B; + delete[] C; + cudaCheckLastError(cudaFree(d_A)); + cudaCheckLastError(cudaFree(d_qB)); + cudaCheckLastError(cudaFree(d_C)); +} diff --git a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp new file mode 100644 index 000000000..ae4cef5a2 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp @@ -0,0 +1,822 @@ +/* + * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#ifndef MARLIN_CUDA_KERNEL_CUH +#define MARLIN_CUDA_KERNEL_CUH + + +#include +#include +#include +#include + + +constexpr int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core +// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that +// are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for +// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need +// for inputs A and outputs C. +__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]) + ); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem) + ); +} + +// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to +// automatically recognize it in all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) + ); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. +// We mostly follow the strategy in the link below, with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2( + *reinterpret_cast(&lo), + *reinterpret_cast(&SUB) + ); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible globally. + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier. + asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + + +template < + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const int stages, // number of stages for the async global->shared fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale +> +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple + // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs + // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as + // possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case + // where a stripe starts in the middle of group. + if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to top + + // We can easily implement parallel problem execution by just remapping indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for synchronization. + auto init_slice = [&] () { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time constant + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major + // layout in the former and in row-major in the latter case. + if (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than + // required for a certain tilesize or when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank + // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of + // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based + // on NSight-Compute) that each warp must also write a consecutive memory segment? + auto transform_a = [&] (int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory + // accesses are static, we simply precompute both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between + // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&] () { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location. + auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i] + ); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) + cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&] () { + // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when + // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer. + auto fetch_to_registers = [&] (int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a + // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the + // compiler and correspondingly a noticable drop in performance. + if (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&] (int k) { + // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + FragB frag_b0 = dequant(b_quant); + // If there are no groups, we can just scale the final output once and can avoid doing so for each weight. + if (group_blocks != -1) + scale(frag_b0, frag_s[k % 2][j], 0); + FragB frag_b1 = dequant(b_quant_shift); + if (group_blocks != -1) + scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n + // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&] () { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations, + // e.g., for two warps we write only once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over + // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. + auto global_reduce = [&] (bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step. + // To do this, we write out results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns, + // hence we also use async-copies even though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m + ); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float( + reinterpret_cast<__half*>(&c_red)[j] + ); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = __float2half( + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] + ); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step, + // the reduction above is performed in fragment layout. + auto write_result = [&] () { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final global write patterns + auto write = [&] (int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + if (group_blocks == -1) // for per-column quantization we finally apply the scale here + res = __hmul2(res, s[0]); + ((half2*) sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&] () { + #pragma unroll + for (int i = 0; i < stages - 1; i++) + fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are + // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) + break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most + // readable, other ways of writing the loop seemed to noticeably worse performance after compliation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before write-out + if (group_blocks == -1 && last) { + if (s_sh_wr_pred) + cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + thread_block_reduce(); + if (group_blocks == -1 && last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + if (slice_count > 1) { // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + + +// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more +// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles. +const int THREADS = 256; +const int STAGES = 4; // 4 pipeline stages fit into shared memory +const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if ( \ + thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS \ + ) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + SHARED_MEM \ + ); \ + Marlin< \ + THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \ + ><<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, \ + prob_m, prob_n, prob_k, \ + locks \ + ); \ + } + +const int ERR_PROB_SHAPE = 1; +const int ERR_KERN_SHAPE = 2; + +int marlin_cuda( + const void* A, + const void* B, + void* C, + void* s, + int prob_m, + int prob_n, + int prob_k, + void* workspace, + int groupsize = -1, + int dev = 0, + cudaStream_t stream = 0, + int thread_k = -1, + int thread_n = -1, + int sms = -1, + int max_par = 16 +) { + int tot_m = prob_m; + int tot_m_blocks = ceildiv(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + if (thread_k == -1 || thread_n == -1) { + if (prob_m <= 16) { + // For small batchizes, better partioning is slightly more important than better compute utilization + thread_k = 128; + thread_n = 128; + } else { + thread_k = 64; + thread_n = 256; + } + } + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0)) + return ERR_PROB_SHAPE; + if (prob_m == 0 || prob_n == 0 || prob_k == 0) + return 0; + + const int4* A_ptr = (const int4*) A; + const int4* B_ptr = (const int4*) B; + int4* C_ptr = (int4*) C; + const int4* s_ptr = (const int4*) s; + + int cols = prob_n / thread_n; + int* locks = (int*) workspace; + + int ret = 0; + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) + par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance) + // in our testing, however many more are, in principle, possible. + if (false) {} + CALL_IF(1, 8, 8, -1) + CALL_IF(1, 8, 8, 8) + CALL_IF(1, 16, 4, -1) + CALL_IF(1, 16, 4, 8) + CALL_IF(2, 16, 4, -1) + CALL_IF(2, 16, 4, 8) + CALL_IF(3, 16, 4, -1) + CALL_IF(3, 16, 4, 8) + CALL_IF(4, 16, 4, -1) + CALL_IF(4, 16, 4, 8) + else + ret = ERR_KERN_SHAPE; + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } + + return ret; +} + + +#endif diff --git a/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp b/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp new file mode 100644 index 000000000..64248b3d1 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp @@ -0,0 +1,89 @@ +#include +#include +#include +#include +#include +#include + +// Helper function to interleave the perm array +std::vector interleave_perms(const std::vector& perm) { + std::vector interleaved_perm; + std::array interleave = {0, 2, 4, 6, 1, 3, 5, 7}; + + int num_rows = perm.size() / 8; + for (int i = 0; i < num_rows; ++i) { + std::array row; + std::copy(perm.begin() + i * 8, perm.begin() + (i + 1) * 8, row.begin()); + for (int j : interleave) { + interleaved_perm.push_back(row[j]); + } + } + + return interleaved_perm; +} + +std::tuple, std::vector, std::vector> get_perms() { + std::vector perm; + + for (int i = 0; i < 32; ++i) { + std::vector perm1; + int col = i / 4; + for (int block : {0, 1}) { + for (int row : { + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1 + }) { + perm1.push_back(16 * row + col + 8 * block); + } + } + for (int j = 0; j < 4; ++j) { + for (int p : perm1) { + perm.push_back(p + 256 * j); + } + } + } + + // Interleave the perm array + perm = interleave_perms(perm); + + std::vector scale_perm; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + scale_perm.push_back(i + 8 * j); + } + } + + std::vector scale_perm_single; + for (int i = 0; i < 4; ++i) { + for (int j : {0, 1, 8, 9, 16, 17, 24, 25}) { + scale_perm_single.push_back(2 * i + j); + } + } + + return std::make_tuple(perm, scale_perm, scale_perm_single); +} + +TEST(EfficientI4MatmulTest, ParamPermutate) +{ + auto [perm, scale_perm, scale_perm_single] = get_perms(); + + std::cout << "perm: "; + for (int i = 0; i < 10; ++i) { + std::cout << perm[i] << " "; + } + std::cout << std::endl; + + std::cout << "scale_perm: "; + for (const auto& val : scale_perm) { + std::cout << val << " "; + } + std::cout << std::endl; + + std::cout << "scale_perm_single: "; + for (const auto& val : scale_perm_single) { + std::cout << val << " "; + } + std::cout << std::endl; +} From 6b73a210ff846fb55bed253b5c0a1c089c3e95f1 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 7 Jul 2024 16:39:04 +0000 Subject: [PATCH 13/17] Refactor copyright notice in i4matmul.hpp --- THIRDPARTYNOTICES.txt | 204 ++++++++++++++++++ .../cpp/efficient_i4_cuda_impl/i4matmul.hpp | 36 ++-- 2 files changed, 224 insertions(+), 16 deletions(-) diff --git a/THIRDPARTYNOTICES.txt b/THIRDPARTYNOTICES.txt index f377e67bb..d959effbb 100644 --- a/THIRDPARTYNOTICES.txt +++ b/THIRDPARTYNOTICES.txt @@ -206,3 +206,207 @@ Notice for apache/tvm limitations under the License. ------------------------------------------------------------------------------------ +Notice for IST-DASLab/marlin/ +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +------------------------------------------------------------------------------------ diff --git a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp index ae4cef5a2..a12a57dcd 100644 --- a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp +++ b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp @@ -1,19 +1,23 @@ -/* - * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - +// Copyright 2018 The apache/tvm Authors. All Rights Reserved. +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// Modifications Copyright (c) Microsoft. +// The code below is mostly copied from marlin_cuda in IST-DASLab/marlin. #ifndef MARLIN_CUDA_KERNEL_CUH #define MARLIN_CUDA_KERNEL_CUH From 086d208fd07984c8cf876f1529fc80ade4cff21d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 7 Jul 2024 16:54:51 +0000 Subject: [PATCH 14/17] Refactor BitBLASLinear test module for improved readability and maintainability --- testing/python/module/test_bitblas_linear.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index 3da4a73b6..f329a146e 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -10,6 +10,7 @@ torch.manual_seed(0) bitblas.set_log_level("DEBUG") + def correctness_consistent(m, in_features, out_features, bias): linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda()) linear_bitblas = BitBLASLinear( @@ -44,6 +45,7 @@ def test_correctness_consistent(): correctness_consistent(1024, 1024, 1024, True) correctness_consistent([1, 1024], 1024, 1024, True) + def correctness_weight_only_dequantize( m, in_features, From 47a3abdb805d2d27241888dce310c9f08f0771fb Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 04:47:15 +0000 Subject: [PATCH 15/17] refactor test as version below python 3.9 cannot handle int32 overflow. --- .../test_int4b_fp16_convert.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/testing/python/type_conversion/test_int4b_fp16_convert.py b/testing/python/type_conversion/test_int4b_fp16_convert.py index 92b0e0788..c1ed480fb 100644 --- a/testing/python/type_conversion/test_int4b_fp16_convert.py +++ b/testing/python/type_conversion/test_int4b_fp16_convert.py @@ -44,25 +44,25 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): if nbits == 1 and target_dtype == "int8": # special handling for 1b interleave - n16_weight = new_qweight & np.int32(0xF0F00F0F) - n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 - n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 - n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 - n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 + n16_weight = new_qweight & np.int32(np.uint32(0xF0F00F0F)) + n16_weight |= ((new_qweight & np.int32(np.uint32(0x000000F0))) >> 4) << 16 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x0000F000))) >> 12) << 24 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x000F0000))) >> 16) << 4 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x0F000000))) >> 24) << 12 return n16_weight.view(np.int8) elif nbits == 2 and target_dtype == "float16": - n8_weight = new_qweight & np.int32(0xFF0000FF) - n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 - n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 + n8_weight = new_qweight & np.int32(np.uint32(0xFF0000FF)) + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0000FF00))) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00FF0000))) >> 16) << 8 return n8_weight.view(np.int8) elif nbits == 1 and target_dtype == "float16": - n8_weight = new_qweight & 0xF000000F - n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 - n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 - n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 - n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 - n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 - n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 + n8_weight = new_qweight & np.int32(np.uint32(0xF000000F)) + n8_weight |= ((new_qweight & np.int32(np.uint32(0x000000F0))) >> 4) << 8 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00000F00))) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0000F000))) >> 12) << 24 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x000F0000))) >> 16) << 4 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00F00000))) >> 20) << 12 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0F000000))) >> 24) << 20 return new_qweight.view(np.int8) From 024b2474b8b81a3f855be5e6ccfcd05139bd285a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 04:47:36 +0000 Subject: [PATCH 16/17] format lint for test --- .../test_int4b_fp16_convert.py | 56 +++++-------------- 1 file changed, 15 insertions(+), 41 deletions(-) diff --git a/testing/python/type_conversion/test_int4b_fp16_convert.py b/testing/python/type_conversion/test_int4b_fp16_convert.py index c1ed480fb..2af765047 100644 --- a/testing/python/type_conversion/test_int4b_fp16_convert.py +++ b/testing/python/type_conversion/test_int4b_fp16_convert.py @@ -21,9 +21,7 @@ def general_compress_to_int8(lowprecision_weight, source_bits=4): ) for j in range(lowprecision_weight.shape[-1] // elems_per_byte): for k in range(elems_per_byte): - int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << ( - source_bits * k - ) + int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k) return int8_weight @@ -80,17 +78,11 @@ def interleave_weight(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32 with T.block("B"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) @T.prim_func - def interleave_weight_f16_2b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_f16_2b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -98,12 +90,8 @@ def interleave_weight_f16_2b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -114,9 +102,7 @@ def interleave_weight_f16_2b( B[v0, v1] = B_tmp_1[v0, v1] | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] @T.prim_func - def interleave_weight_f16_1b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_f16_1b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -128,12 +114,8 @@ def interleave_weight_f16_1b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -152,13 +134,10 @@ def interleave_weight_f16_1b( | B_tmp_4[v0, v1] | B_tmp_5[v0, v1] | B_tmp_6[v0, v1] - | B_tmp_7[v0, v1] - ) + | B_tmp_7[v0, v1]) @T.prim_func - def interleave_weight_int8_1b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_int8_1b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -168,12 +147,8 @@ def interleave_weight_int8_1b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -188,8 +163,7 @@ def interleave_weight_int8_1b( | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] | B_tmp_4[v0, v1] - | B_tmp_5[v0, v1] - ) + | B_tmp_5[v0, v1]) if target_dtype == "float16" and bits == 2: return interleave_weight_f16_2b @@ -207,7 +181,7 @@ def test_lop3_interleave_weight(): K = 16 target_dtype = "float16" torch.manual_seed(0) - uint_max = 2 ** (source_nbits) - 1 + uint_max = 2**(source_nbits) - 1 raw_data = torch.randint(0, uint_max, (N, K), dtype=torch.int8).cpu().numpy() compressed_b = general_compress_to_int8(raw_data, source_nbits) interleaved_weight = interleave_weight(compressed_b, source_nbits, target_dtype) From bfedeaa813c269330ea332cce9938e05eac809ec Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 04:50:49 +0000 Subject: [PATCH 17/17] Refactor test_int4b_fp16_convert.py for improved readability and maintainability --- testing/python/type_conversion/test_int4b_fp16_convert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/testing/python/type_conversion/test_int4b_fp16_convert.py b/testing/python/type_conversion/test_int4b_fp16_convert.py index 2af765047..3a58a47e1 100644 --- a/testing/python/type_conversion/test_int4b_fp16_convert.py +++ b/testing/python/type_conversion/test_int4b_fp16_convert.py @@ -5,7 +5,6 @@ import torch import numpy as np from tvm.script import tir as T -import numpy as np def general_compress_to_int8(lowprecision_weight, source_bits=4):