diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index 61485b09112b..a019b87f3a2b 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -22,6 +22,7 @@ #include #include +#include namespace tvm { namespace relax { @@ -108,10 +109,10 @@ ffi::Array GetTensorStructInfoFromTuple(const Call& call, cons return tensor_sinfo; } -ffi::Optional> InferBinaryBroadcastShape( - const Call& call, const BlockBuilder& ctx, const ffi::Array& x1_shape, - const ffi::Array& x2_shape) { - arith::Analyzer* analyzer = ctx->GetAnalyzer(); +BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer* analyzer, + const ffi::Array& x1_shape, + const ffi::Array& x2_shape) { + BinaryBroadcastShapeInferResult result; int x1_ndim = x1_shape.size(); int x2_ndim = x2_shape.size(); int max_ndim = std::max(x1_ndim, x2_ndim); @@ -132,20 +133,45 @@ ffi::Optional> InferBinaryBroadcastShape( } else if (analyzer->CanProveEqual(dim0, dim1)) { output_shape.push_back(dim0); } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) { - ctx->ReportFatal(Diagnostic::Error(call) - << "In " << call->op << ", the first input shape at dim " << x1_ndim - i - << " is " << dim0 << " and the second input shape at dim " << x2_ndim - i - << " is " << dim1 << ", which are not broadcastable."); + result.status = BinaryBroadcastShapeInferResult::Status::kConflict; + result.message = [&]() { + std::ostringstream os; + os << "the first input shape at dim " << x1_ndim - i << " is " << dim0 + << " and the second input shape at dim " << x2_ndim - i << " is " << dim1 + << ", which are not broadcastable."; + return ffi::String(os.str()); + }(); + return result; } else { - // Use simple fallback when shape mismatch. - return std::nullopt; + result.status = BinaryBroadcastShapeInferResult::Status::kUnknown; + return result; } } auto& longer_shape = (x1_ndim > x2_ndim) ? x1_shape : x2_shape; for (; i <= max_ndim; ++i) { output_shape.push_back(longer_shape[max_ndim - i]); } - return ffi::Array(output_shape.rbegin(), output_shape.rend()); + result.status = BinaryBroadcastShapeInferResult::Status::kSuccess; + result.shape = ffi::Array(output_shape.rbegin(), output_shape.rend()); + return result; +} + +ffi::Optional> InferBinaryBroadcastShape( + const Call& call, const BlockBuilder& ctx, const ffi::Array& x1_shape, + const ffi::Array& x2_shape) { + auto infer_result = InferBinaryBroadcastShape(ctx->GetAnalyzer(), x1_shape, x2_shape); + if (infer_result.status == BinaryBroadcastShapeInferResult::Status::kConflict) { + TVM_FFI_ICHECK(infer_result.message.has_value()); + ctx->ReportFatal(Diagnostic::Error(call) + << "In " << call->op << ", " << infer_result.message.value()); + } else if (infer_result.status == BinaryBroadcastShapeInferResult::Status::kSuccess) { + TVM_FFI_ICHECK(infer_result.shape.has_value()); + return infer_result.shape.value(); + } else { + // Unknown status, use simple fallback when shape mismatch. + return std::nullopt; + } + TVM_FFI_UNREACHABLE(); } std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 774eccfd58dd..6f7de974cbe6 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -387,6 +387,36 @@ inline ffi::Optional InferBinaryArithOpOutVDevice(const Call& call, return lhs_vdevice; } +/*! \brief Result of binary broadcast shape inference without diagnostic context. */ +struct BinaryBroadcastShapeInferResult { + enum class Status { + /*! \brief Broadcast output shape is known. */ + kSuccess, + /*! \brief Shapes may be broadcastable but cannot be proved symbolically. */ + kUnknown, + /*! \brief Concrete shapes are not broadcastable. */ + kConflict, + }; + + /*! \brief Inference status. */ + Status status = Status::kUnknown; + /*! \brief Broadcasted shape if status is kSuccess. */ + ffi::Optional> shape; + /*! \brief Human-readable conflict description if status is kConflict. */ + ffi::Optional message; +}; + +/*! + * \brief Infer the output shape for binary broadcast operators. + * \param analyzer The arithmetic analyzer used to prove shape equality. + * \param x1_shape The shape of the first operand. + * \param x2_shape The shape of the second operand. + * \return Inference status and broadcasted shape, or a conflict message. + */ +BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer* analyzer, + const ffi::Array& x1_shape, + const ffi::Array& x2_shape); + /*! * \brief Infer the output shape for binary broadcast operators. * \param call The context Call to the operator. diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 9ea47aa64844..012c8ce5b71a 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -34,6 +34,7 @@ #include #include +#include "../op/op_common.h" #include "../op/tensor/linear_algebra.h" #include "../op/tensor/manipulate.h" @@ -41,6 +42,27 @@ namespace tvm { namespace relax { namespace { + +ffi::Array GetBatchPrefix(const ffi::Array& shape) { + if (shape.size() <= 2) return {}; + return {shape.begin(), shape.end() - 2}; +} + +PrimExpr ProductDims(const ffi::Array& dims) { + PrimExpr product = IntImm(DataType::Int(64), 1); + for (const auto& dim : dims) product = product * dim; + return product; +} + +ffi::Optional> InferBatchedMatmulBroadcastPrefix( + arith::Analyzer* analyzer, const ffi::Array& x1, const ffi::Array& x2) { + auto infer_result = InferBinaryBroadcastShape(analyzer, x1, x2); + if (infer_result.status == BinaryBroadcastShapeInferResult::Status::kSuccess) { + return infer_result.shape; + } + return std::nullopt; +} + std::tuple)>> CreatePatterns( const Function& func) { auto compile_time_arr = ComputableAtCompileTime(func); @@ -141,20 +163,46 @@ std::tuple)>> auto shape_b = opt_shape_b.value(); auto shape_c = opt_shape_c.value(); + auto permute_last_two_dims = [&](Expr expr) -> Expr { + auto opt_shape = get_shape(expr); + if (!opt_shape) return expr; + + size_t ndim = opt_shape.value().size(); + TVM_FFI_ICHECK_GE(ndim, 2); + + ffi::Optional> axes; + + if (ndim == 2) { + // Pass none axes to permute_dims for simple transpose of 2D tensors. + axes = std::nullopt; + } else { + ffi::Array axes_array; + for (size_t i = 0; i < ndim; ++i) axes_array.push_back(i); + axes_array.Set(ndim - 1, ndim - 2); + axes_array.Set(ndim - 2, ndim - 1); + axes = ffi::Optional>(axes_array); + } + return permute_dims(std::move(expr), axes); + }; + + auto transpose_shape_last_two_dims = [&](ffi::Array& shape) { + PrimExpr last_dim_shape = shape[shape.size() - 1]; + shape.Set(shape.size() - 1, shape[shape.size() - 2]); + shape.Set(shape.size() - 2, last_dim_shape); + }; + if (matches.count(pat_permuted_matmul_on_lhs)) { - expr_a = permute_dims(expr_a, std::nullopt); - expr_b = permute_dims(expr_b, std::nullopt); - TVM_FFI_ICHECK_EQ(shape_a.size(), 2); - TVM_FFI_ICHECK_EQ(shape_b.size(), 2); - shape_a = {shape_a[1], shape_a[0]}; - shape_b = {shape_b[1], shape_b[0]}; + if (shape_a.size() < 2 || shape_b.size() < 2) return expr; + expr_a = permute_last_two_dims(expr_a); + expr_b = permute_last_two_dims(expr_b); + transpose_shape_last_two_dims(shape_a); + transpose_shape_last_two_dims(shape_b); } else if (matches.count(pat_permuted_matmul_on_rhs)) { - expr_b = permute_dims(expr_b, std::nullopt); - expr_c = permute_dims(expr_c, std::nullopt); - TVM_FFI_ICHECK_EQ(shape_b.size(), 2); - TVM_FFI_ICHECK_EQ(shape_c.size(), 2); - shape_b = {shape_b[1], shape_b[0]}; - shape_c = {shape_c[1], shape_c[0]}; + if (shape_b.size() < 2 || shape_c.size() < 2) return expr; + expr_b = permute_last_two_dims(expr_b); + expr_c = permute_last_two_dims(expr_c); + transpose_shape_last_two_dims(shape_b); + transpose_shape_last_two_dims(shape_c); } // If two of the three are compile-time, group those two values @@ -166,13 +214,7 @@ std::tuple)>> } // Otherwise, select the order that reduces the total number of - // operations required, assuming a naive matmul. - - // Matmul on LHS: ([N,R]*[R,M]) * [M,batch] - // Matmul on RHS: [N,R] * ([R,M]*[M,batch]) - // - // LHS first: `N*R*M + N*M*batch = N*M*(R+batch)` - // RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch` + // operations required, assuming a naive matmul (see below). if (shape_a.size() == 1) { shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]}; @@ -192,21 +234,54 @@ std::tuple)>> shape_c = {shape_c[0], IntImm(shape_c[0].dtype(), 1)}; } - auto size_N = shape_a[shape_a.size() - 2]; - auto size_R = shape_a[shape_a.size() - 1]; - auto size_M = shape_c[shape_c.size() - 2]; - auto size_B = shape_c[shape_c.size() - 1]; - - auto ops_with_lhs_first = (size_R + size_B) * size_N * size_M; - auto ops_with_rhs_first = (size_M + size_N) * size_R * size_B; + PrimExpr size_N = shape_a[shape_a.size() - 2]; // row of A + PrimExpr size_R = shape_a[shape_a.size() - 1]; // col of A and row of B + PrimExpr size_M = shape_c[shape_c.size() - 2]; // row of C and col of B + PrimExpr size_B = shape_c[shape_c.size() - 1]; // col of C arith::Analyzer analyzer; + auto prefix_a = GetBatchPrefix(shape_a); + auto prefix_b = GetBatchPrefix(shape_b); + auto prefix_c = GetBatchPrefix(shape_c); + + auto opt_prefix_ab = InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_a, prefix_b); + if (!opt_prefix_ab) return expr; + auto opt_prefix_bc = InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_b, prefix_c); + if (!opt_prefix_bc) return expr; + auto opt_prefix_outer_lhs = + InferBatchedMatmulBroadcastPrefix(&analyzer, opt_prefix_ab.value(), prefix_c); + if (!opt_prefix_outer_lhs) return expr; + auto opt_prefix_outer_rhs = + InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_a, opt_prefix_bc.value()); + if (!opt_prefix_outer_rhs) return expr; + + PrimExpr batch_ab = ProductDims(opt_prefix_ab.value()); + PrimExpr batch_bc = ProductDims(opt_prefix_bc.value()); + PrimExpr batch_outer_lhs = ProductDims(opt_prefix_outer_lhs.value()); + PrimExpr batch_outer_rhs = ProductDims(opt_prefix_outer_rhs.value()); + + // Compare naive matmul FLOPs for two evaluation orders of + // matmul(A, matmul(B, C)) vs matmul(matmul(A, B), C) + // + // Matrix dims (last two axes): A [N, R], B [R, M], C [M, B_last] + // Each matmul uses the broadcasted batch prefix of its operands. + // + // LHS first — matmul(matmul(A, B), C): + // batch_ab * N * R * M + batch_outer_lhs * N * M * B_last + PrimExpr ops_with_lhs_first = + batch_ab * size_N * size_R * size_M + batch_outer_lhs * size_N * size_M * size_B; + // RHS first — matmul(A, matmul(B, C)): + // batch_bc * R * M * B_last + batch_outer_rhs * N * R * B_last + PrimExpr ops_with_rhs_first = + batch_bc * size_R * size_M * size_B + batch_outer_rhs * size_N * size_R * size_B; + analyzer.rewrite_simplify.SetEnabledExtensions(static_cast( analyzer.rewrite_simplify.GetEnabledExtensions() | arith::RewriteSimplifier::Extension::kComparisonOfProductAndSum)); With func_attr_constraint(&analyzer, symbolic_var_constraints); With analyzer_constraint( - &analyzer, size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0); + &analyzer, batch_ab > 0 && batch_bc > 0 && batch_outer_lhs > 0 && batch_outer_rhs > 0 && + size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0); if (analyzer.CanProve(ops_with_lhs_first < ops_with_rhs_first)) { return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); @@ -214,8 +289,7 @@ std::tuple)>> return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); } - // If we cannot determine which order is best, keep the existing - // order. + // If we cannot determine which order is best, keep the existing order. return expr; }; diff --git a/tests/python/relax/test_transform_adjust_matmul_order.py b/tests/python/relax/test_transform_adjust_matmul_order.py index a086f3abdb8d..9600c97bdaac 100644 --- a/tests/python/relax/test_transform_adjust_matmul_order.py +++ b/tests/python/relax/test_transform_adjust_matmul_order.py @@ -17,8 +17,11 @@ import inspect +import numpy as np import pytest +import torch +import tvm import tvm.testing from tvm import relax from tvm.script import ir as I @@ -39,7 +42,13 @@ def test_compare(self): class TestLHS(Base): - """Prefer (x*A)*B instead of x*(A*B)""" + """Prefer (x*A)*B instead of x*(A*B) + + LHS first - (x*A)*B: + ops = 1*16*2 + 1*2*32 = 96 + RHS first - x*(A*B): + ops = 16*2*32 + 1*16*32 = 1536 + """ @I.ir_module class Before: @@ -67,7 +76,13 @@ def main( class TestRHS(Base): - """Prefer A*(B*x) instead of (A*B)*x""" + """Prefer A*(B*x) instead of (A*B)*x + + LHS first - (A*B)*x: + ops = 32*2*16 + 32*16*1 = 1536 + RHS first - A*(B*x): + ops = 2*16*1 + 32*2*1 = 96 + """ @I.ir_module class Before: @@ -163,6 +178,13 @@ class TestLHSDynamic(Base): This case appears when evaluating LoRA-tuned models with a dynamic rank. + + LHS first - (x*A)*B: + ops = 1*16*lora_r + 1*lora_r*32 = 48*lora_r + RHS first - x*(A*B): + ops = 16*lora_r*32 + 1*16*32 = 512*lora_r + 512 + + 48*lora_r can be proved to be less than 512*lora_r + 512, so the LHS first is preferred. """ @I.ir_module @@ -192,7 +214,15 @@ def main( class TestRHSDynamic(Base): - """Prefer A*(B*x) instead of (A*B)*x""" + """Prefer A*(B*x) instead of (A*B)*x + + LHS first - (A*B)*x: + ops = 32*lora_r*16 + 32*16*1 = 512*lora_r + 512 + RHS first - A*(B*x): + ops = lora_r*16*1 + 32*lora_r*1 = 48*lora_r + + 48*lora_r can be proved to be less than 512*lora_r + 512, so the RHS first is preferred. + """ @I.ir_module class Before: @@ -234,8 +264,27 @@ class TestIdempotentRHSDynamic(Base): Expected = TestRHSDynamic.Expected -class TestLHSDynamicWithBatch(Base): - """Prefer (x*A)*B instead of x*(A*B)""" +class TestDynamicWithBatchSymbolic1(Base): + """When both batch_size and lora_r are symbolic and it cannot be proven which + is cheaper, LHS or RHS, maintain the existing order. + + `Before` computes `x * (A * B)` with + `x: [batch_size, 1, 16]`, `A: [16, lora_r]`, `B: [lora_r, 32]`. + + RHS first - x * (A * B): + 16*lora_r*32 + batch_size*1*16*32 = 512*(lora_r + batch_size) + + LHS first - (x * A) * B: + batch_size*1*16*lora_r + batch_size*1*lora_r*32 = 48*batch_size*lora_r + + When `batch_size` and `lora_r` are known at compile-time: + - satisfy the inequality 48*batch_size*lora_r < 512*(lora_r + batch_size), + the LHS first is preferred. + - satisfy the inequality 512*(lora_r + batch_size) < 48*batch_size*lora_r, + the RHS first is preferred. + + Without bounds on `batch_size` and `lora_r`, neither side is provably cheaper. + """ @I.ir_module class Before: @@ -250,6 +299,31 @@ def main( out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight) return out + Expected = Before + + +class TestDynamicWithBatchConcrete1LHSFirst(Base): + """With concrete shapes, LHS first is provably cheaper. + + batch_size=4, lora_r=16: + LHS first: 48*4*16 = 3072 + RHS first: 512*(16 + 4) = 10240 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 1, 16]), + A: R.Tensor([16, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor(["batch_size", 1, 32]): + batch_size = T.int64(4) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([16, 32]) = R.matmul(A, B) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight) + return out + @I.ir_module class Expected: @R.function @@ -258,15 +332,71 @@ def main( A: R.Tensor([16, "lora_r"]), B: R.Tensor(["lora_r", 32]), ) -> R.Tensor(["batch_size", 1, 32]): - lora_r = T.int64() - batch_size = T.int64() - x: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A) - x: R.Tensor([batch_size, 1, 32]) = R.matmul(x, B) - return x + batch_size = T.int64(4) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(weight, B) + return out -class TestRHSDynamicWithBatch(Base): - """Prefer A*(B*x) instead of (A*B)*x""" +class TestDynamicWithBatchConcrete1RHSFirst(Base): + """With concrete shapes, RHS first is provably cheaper. + + batch_size=64, lora_r=16: + LHS first: 48*64*16 = 49152 + RHS first: 512*(16 + 64) = 40960 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 1, 16]), + A: R.Tensor([16, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor(["batch_size", 1, 32]): + batch_size = T.int64(64) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(weight, B) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["batch_size", 1, 16]), + A: R.Tensor([16, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor(["batch_size", 1, 32]): + batch_size = T.int64(64) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([16, 32]) = R.matmul(A, B) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight) + return out + + +class TestDynamicWithBatchSymbolic2(Base): + """When both batch_size and lora_r are symbolic and it cannot be proven which + is cheaper, LHS or RHS, maintain the existing order. + + `Before` computes `(A * B) * x` with + `A: [32, lora_r]`, `B: [lora_r, 16]`, `x: [batch_size, 16, 1]`. + + LHS first - (A * B) * x: + 32*lora_r*16 + batch_size*32*16*1 = 512*(lora_r + batch_size) + + RHS first - A * (B * x): + batch_size*lora_r*16*1 + batch_size*32*lora_r*1 = 48*batch_size*lora_r + + When `batch_size` and `lora_r` are known at compile-time: + - satisfy the inequality 48*batch_size*lora_r < 512*(lora_r + batch_size), + the RHS first is preferred. + - satisfy the inequality 512*(lora_r + batch_size) < 48*batch_size*lora_r, + the LHS first is preferred. + + Without bounds on `batch_size` and `lora_r`, neither side is provably cheaper. + """ @I.ir_module class Before: @@ -281,6 +411,31 @@ def main( out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x) return out + Expected = Before + + +class TestDynamicWithBatchConcrete2RHSFirst(Base): + """With concrete shapes, RHS first is provably cheaper. + + batch_size=4, lora_r=16: + RHS first: 48*4*16 = 3072 + LHS first: 512*(16 + 4) = 10240 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 16, 1]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor(["batch_size", 32, 1]): + batch_size = T.int64(4) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([32, 16]) = R.matmul(A, B) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x) + return out + @I.ir_module class Expected: @R.function @@ -289,11 +444,48 @@ def main( A: R.Tensor([32, "lora_r"]), B: R.Tensor(["lora_r", 16]), ) -> R.Tensor(["batch_size", 32, 1]): - lora_r = T.int64() - batch_size = T.int64() - x: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x) - x: R.Tensor([batch_size, 32, 1]) = R.matmul(A, x) - return x + batch_size = T.int64(4) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(A, weight) + return out + + +class TestDynamicWithBatchConcrete2LHSFirst(Base): + """With concrete shapes, LHS first is provably cheaper. + + batch_size=64, lora_r=16: + RHS first: 48*64*16 = 49152 + LHS first: 512*(16 + 64) = 40960 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 16, 1]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor(["batch_size", 32, 1]): + batch_size = T.int64(64) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(A, weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["batch_size", 16, 1]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor(["batch_size", 32, 1]): + batch_size = T.int64(64) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([32, 16]) = R.matmul(A, B) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x) + return out class TestNoOpForFullyDynamicOnLHS(Base): @@ -353,6 +545,11 @@ class TestRHSPermuteDims(Base): """Prefer (x*A)*B instead of x*(A*B) Like `TestRHS`, but the weights on the RHS are transposed. + + Before: x * (BT * AT) + ops = 16*2*32 + 1*16*32 = 1536 + After: (x * BT) * AT + ops = 1*16*2 + 1*2*32 = 96 """ @I.ir_module @@ -388,6 +585,13 @@ class TestRHSPermuteDimsDynamic(Base): Like `TestRHSPermuteDims`, but the weights on the RHS have a dynamic shape. + + Before: x * (BT * AT) + ops = 16*lora_r*32 + 1*16*32 = 512*lora_r + 512 + After: (x * BT) * AT + ops = 1*16*lora_r + 1*lora_r*32 = 48*lora_r + + 48*lora_r can be proved to be less than 512*lora_r + 512, so the After is preferred. """ @I.ir_module @@ -433,15 +637,15 @@ class TestRHSPermuteDimsWithDynamicBatch(Base): ops_left_to_right = (batch_size + lora_r)*4096*4096 ops_right_to_left = (4096 + 4096)*batch_size*lora_r - Without an upper bound on `lora_r`, we cannot prove which of these - is the preferred execution order. With the upper bound, TVM can - determine the preferred order using the following arithmethic - reasoning. + Without an upper bound on batch_size and`lora_r`, we cannot prove which + of these is the preferred execution order. - (batch_size + lora_r)*4096*4096 < (4096 + 4096)*batch_size*lora_r - (batch_size + lora_r)*2048 < batch_size*lora_r - 1/batch_size + 1/lora_r < 1/2048 + With the upper bound, TVM can determine the preferred order using + the following arithmetic reasoning. + (batch_size + lora_r)*4096*4096 > (4096 + 4096)*batch_size*lora_r + (batch_size + lora_r)*2048 > batch_size*lora_r + 1/batch_size + 1/lora_r > 1/2048 """ @I.ir_module @@ -452,7 +656,12 @@ def main( A: R.Tensor([4096, "lora_r"]), B: R.Tensor(["lora_r", 4096]), ) -> R.Tensor(["batch_size", 4096]): - R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}}) + R.func_attr( + { + "tir_var_upper_bound": {"lora_r": 2048, "batch_size": 2048}, + } + ) + lora_r = T.int64() # noqa: F841 batch_size = T.int64() linear_weight: R.Tensor([4096, 4096]) = R.matmul(A, B) matmul_weight: R.Tensor([4096, 4096]) = R.permute_dims(linear_weight) @@ -467,7 +676,11 @@ def main( A: R.Tensor([4096, "lora_r"]), B: R.Tensor(["lora_r", 4096]), ) -> R.Tensor(["batch_size", 4096]): - R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}}) + R.func_attr( + { + "tir_var_upper_bound": {"lora_r": 2048, "batch_size": 2048}, + } + ) lora_r = T.int64() batch_size = T.int64() B_transpose = R.permute_dims(B) @@ -482,6 +695,11 @@ class TestRHSPermuteDimsDynamicWithSquareMatrix(Base): Like `TestRHSPermuteDims`, but the weights on the RHS have a dynamic shape. + + Before: x * (BT * AT) + ops = 32*lora_r*32 + 1*32*32 = 1024*lora_r + 1024 + After: (x * BT) * AT + ops = 1*32*lora_r + 1*lora_r*32 = 64*lora_r """ @I.ir_module @@ -513,5 +731,143 @@ def main( return x +class TestBatchedBroadcastPreferLHSFirst(Base): + """Use broadcasted batch prefix per matmul, not independent prefix products. + + Example with broadcast batch axes: A:[2,1,1], B:[2,1,2], C:[2,2,3]. + + LHS first: (A * B) * C + ops = 2*1*1*2 + 2*1*2*3 = 16 + RHS first: A * (B * C) + ops = 2*1*2*3 + 2*1*1*3 = 18 + """ + + @I.ir_module + class Before: + @R.function + def main( + A: R.Tensor([2, 1, 1]), + B: R.Tensor([2, 1, 2]), + C: R.Tensor([2, 2, 3]), + ) -> R.Tensor([2, 1, 3]): + out: R.Tensor([2, 1, 3]) = R.matmul(A, R.matmul(B, C)) + return out + + @I.ir_module + class Expected: + @R.function + def main( + A: R.Tensor([2, 1, 1]), + B: R.Tensor([2, 1, 2]), + C: R.Tensor([2, 2, 3]), + ) -> R.Tensor([2, 1, 3]): + temp: R.Tensor([2, 1, 2]) = R.matmul(A, B) + out: R.Tensor([2, 1, 3]) = R.matmul(temp, C) + return out + + +class TestBatchedSharedPrefixPreferLHSFirst(Base): + """All operands share a nontrivial batch prefix [2, 3]. + + Shapes: A:[2,3,4,5], B:[2,3,5,6], C:[2,3,6,7] + + LHS first: + ops = 6*4*5*6 + 6*4*6*7 = 1728 + RHS first: + ops = 6*5*6*7 + 6*4*5*7 = 2100 + """ + + @I.ir_module + class Before: + @R.function + def main( + A: R.Tensor([2, 3, 4, 5]), + B: R.Tensor([2, 3, 5, 6]), + C: R.Tensor([2, 3, 6, 7]), + ) -> R.Tensor([2, 3, 4, 7]): + out: R.Tensor([2, 3, 4, 7]) = R.matmul(A, R.matmul(B, C)) + return out + + @I.ir_module + class Expected: + @R.function + def main( + A: R.Tensor([2, 3, 4, 5]), + B: R.Tensor([2, 3, 5, 6]), + C: R.Tensor([2, 3, 6, 7]), + ) -> R.Tensor([2, 3, 4, 7]): + temp: R.Tensor([2, 3, 4, 6]) = R.matmul(A, B) + out: R.Tensor([2, 3, 4, 7]) = R.matmul(temp, C) + return out + + +class TestAdjustMatmulOrderAttentionBlock: + """AdjustMatmulOrder preserves numerics on a batched attention block. + + Covers ND `permute_dims` (swap last two axes) inside `matmul(q, kt)`, + regression for issue #19576. + """ + + def _build_attention_module(self, batch, seq, dim): + """Minimal batched attention block exercising ND permute_dims + matmul.""" + bb = relax.BlockBuilder() + x = relax.Var("x", relax.TensorStructInfo((batch, seq, dim), "float32")) + wq = relax.Var("wq", relax.TensorStructInfo((dim, dim), "float32")) + wk = relax.Var("wk", relax.TensorStructInfo((dim, dim), "float32")) + wv = relax.Var("wv", relax.TensorStructInfo((dim, dim), "float32")) + wo = relax.Var("wo", relax.TensorStructInfo((dim, dim), "float32")) + with bb.function("main", [x, wq, wk, wv, wo]): + with bb.dataflow(): + q = bb.emit(relax.op.matmul(x, wq)) + k = bb.emit(relax.op.matmul(x, wk)) + v = bb.emit(relax.op.matmul(x, wv)) + kt = bb.emit(relax.op.permute_dims(k, axes=[0, 2, 1])) + scores = bb.emit(relax.op.matmul(q, kt)) + scale = bb.emit(relax.const(1.0 / np.sqrt(dim), "float32")) + scores = bb.emit(relax.op.multiply(scores, scale)) + attn = bb.emit(relax.op.nn.softmax(scores, axis=-1)) + out = bb.emit(relax.op.matmul(attn, v)) + proj = bb.emit_output(relax.op.matmul(out, wo)) + bb.emit_func_output(proj) + return bb.finalize() + + def _run_relax_main(self, mod, inputs): + exe = relax.build(mod, target="llvm") + vm = relax.VirtualMachine(exe, device=tvm.cpu()) + args = [tvm.runtime.tensor(arr, device=tvm.cpu()) for arr in inputs] + return vm["main"](*args).numpy() + + def _torch_attention_ref(self, x_np, w_np, dim): + x = torch.from_numpy(x_np) + w = torch.from_numpy(w_np) + with torch.no_grad(): + q = torch.matmul(x, w) + k = torch.matmul(x, w) + v = torch.matmul(x, w) + scores = torch.matmul(q, k.transpose(-2, -1)) + scores = scores * (1.0 / np.sqrt(dim)) + attn = torch.nn.functional.softmax(scores, dim=-1) + out = torch.matmul(attn, v) + out = torch.matmul(out, w) + return out.detach().numpy() + + @pytest.mark.parametrize("batch,seq,dim", [(2, 16, 64)]) + def test_attention_block_numerics(self, batch, seq, dim): + mod = self._build_attention_module(batch, seq, dim) + mod_opt = relax.transform.AdjustMatmulOrder()(mod) + + x_np = np.random.randn(batch, seq, dim).astype("float32") + w_np = np.random.randn(dim, dim).astype("float32") + inputs = [x_np, w_np, w_np, w_np, w_np] + + ref = self._torch_attention_ref(x_np, w_np, dim) + out_before = self._run_relax_main(mod, inputs) + out_after = self._run_relax_main(mod_opt, inputs) + + tvm.testing.assert_allclose(out_before, ref, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(out_after, ref, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(out_before, out_after, rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": tvm.testing.main()