From 1835da3ff5f0ccb22c984f0645673a545090eb47 Mon Sep 17 00:00:00 2001 From: ConvolutedDog Date: Mon, 1 Jun 2026 02:17:55 +0800 Subject: [PATCH 1/2] [Fix][Relax] Support ND batched matmul chains in AdjustMatmulOrder pass Fix a crash (https://github.com/apache/tvm/issues/19576) when AdjustMatmulOrder encounters mixed-dimension matmul chains common in transformer models (e.g. matmul(attn_output[B,S,D], W_o[D,D])). The pass previously assumed all operands in a chained rewrite were 2D and asserted shape_c.size() == 2, failing on 3D intermediate results. Changes: - Replace full 2D transpose with permute_last_two_dims for permuted matmul patterns, swapping only the last two axes for ND tensors. - Remove hard ndim==2 checks in the permuted rewrite path. - Account for batch prefixes when comparing naive matmul FLOPs, so reorder decisions reflect batched vs. weight-only inner matmuls. - Skip reorder when neither evaluation order is provably cheaper. - Add regression tests for symbolic/concrete batched LoRA shapes. - Add a numerics test covering a minimal attention block with ND permute_dims. --- src/relax/transform/adjust_matmul_order.cc | 99 +++++-- .../test_transform_adjust_matmul_order.py | 259 +++++++++++++++++- 2 files changed, 317 insertions(+), 41 deletions(-) diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 9ea47aa64844..efdacb4e0e83 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -141,20 +141,44 @@ std::tuple)>> auto shape_b = opt_shape_b.value(); auto shape_c = opt_shape_c.value(); + auto permute_last_two_dims = [&](Expr expr) -> Expr { + auto opt_shape = get_shape(expr); + if (!opt_shape) return expr; + + size_t ndim = opt_shape.value().size(); + TVM_FFI_ICHECK_GE(ndim, 2); + + ffi::Optional> axes; + + if (ndim == 2) { + // Pass none axes to permute_dims for simple transpose of 2D tensors. 
+ axes = std::nullopt; + } else { + ffi::Array axes_array; + for (size_t i = 0; i < ndim; ++i) axes_array.push_back(i); + axes_array.Set(ndim - 1, ndim - 2); + axes_array.Set(ndim - 2, ndim - 1); + axes = ffi::Optional>(axes_array); + } + return permute_dims(std::move(expr), axes); + }; + + auto transpose_shape_last_two_dims = [&](ffi::Array& shape) { + PrimExpr last_dim_shape = shape[shape.size() - 1]; + shape.Set(shape.size() - 1, shape[shape.size() - 2]); + shape.Set(shape.size() - 2, last_dim_shape); + }; + if (matches.count(pat_permuted_matmul_on_lhs)) { - expr_a = permute_dims(expr_a, std::nullopt); - expr_b = permute_dims(expr_b, std::nullopt); - TVM_FFI_ICHECK_EQ(shape_a.size(), 2); - TVM_FFI_ICHECK_EQ(shape_b.size(), 2); - shape_a = {shape_a[1], shape_a[0]}; - shape_b = {shape_b[1], shape_b[0]}; + expr_a = permute_last_two_dims(expr_a); + expr_b = permute_last_two_dims(expr_b); + transpose_shape_last_two_dims(shape_a); + transpose_shape_last_two_dims(shape_b); } else if (matches.count(pat_permuted_matmul_on_rhs)) { - expr_b = permute_dims(expr_b, std::nullopt); - expr_c = permute_dims(expr_c, std::nullopt); - TVM_FFI_ICHECK_EQ(shape_b.size(), 2); - TVM_FFI_ICHECK_EQ(shape_c.size(), 2); - shape_b = {shape_b[1], shape_b[0]}; - shape_c = {shape_c[1], shape_c[0]}; + expr_b = permute_last_two_dims(expr_b); + expr_c = permute_last_two_dims(expr_c); + transpose_shape_last_two_dims(shape_b); + transpose_shape_last_two_dims(shape_c); } // If two of the three are compile-time, group those two values @@ -166,13 +190,7 @@ std::tuple)>> } // Otherwise, select the order that reduces the total number of - // operations required, assuming a naive matmul. - - // Matmul on LHS: ([N,R]*[R,M]) * [M,batch] - // Matmul on RHS: [N,R] * ([R,M]*[M,batch]) - // - // LHS first: `N*R*M + N*M*batch = N*M*(R+batch)` - // RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch` + // operations required, assuming a naive matmul (see below). 
if (shape_a.size() == 1) { shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]}; @@ -192,13 +210,41 @@ std::tuple)>> shape_c = {shape_c[0], IntImm(shape_c[0].dtype(), 1)}; } - auto size_N = shape_a[shape_a.size() - 2]; - auto size_R = shape_a[shape_a.size() - 1]; - auto size_M = shape_c[shape_c.size() - 2]; - auto size_B = shape_c[shape_c.size() - 1]; + PrimExpr size_N = shape_a[shape_a.size() - 2]; // row of A + PrimExpr size_R = shape_a[shape_a.size() - 1]; // col of A and row of B + PrimExpr size_M = shape_c[shape_c.size() - 2]; // row of C and col of B + PrimExpr size_B = shape_c[shape_c.size() - 1]; // col of C + + auto calculate_batch = [](ffi::Array& shape) { + PrimExpr batch = 1; + for (size_t i = 0; i < shape.size() - 2; ++i) { + batch *= shape[i]; + } + return batch; + }; + + PrimExpr batch_A = calculate_batch(shape_a); + PrimExpr batch_B = calculate_batch(shape_b); + PrimExpr batch_C = calculate_batch(shape_c); - auto ops_with_lhs_first = (size_R + size_B) * size_N * size_M; - auto ops_with_rhs_first = (size_M + size_N) * size_R * size_B; + // Compare naive matmul FLOPs for two evaluation orders of + // matmul(A, matmul(B, C)) vs matmul(matmul(A, B), C) + // + // Matrix dims (last two axes of each operand): + // A: [N, R] B: [R, M] C: [M, B_last] + // Batch prefixes (product of all leading axes): + // batch_A, batch_B, batch_C + // + // LHS first — matmul(matmul(A, B), C): + // inner matmul(A, B): batch_A * batch_B * N * R * M + // outer matmul(., C): batch_A * batch_B * batch_C * N * M * B_last + // total: batch_A * batch_B * N * M * (R + batch_C * B_last) + PrimExpr ops_with_lhs_first = (size_R + batch_C * size_B) * size_N * size_M * batch_A * batch_B; + // RHS first — matmul(A, matmul(B, C)): + // inner matmul(B, C): batch_B * batch_C * R * M * B_last + // outer matmul(A, .): batch_A * batch_B * batch_C * N * R * B_last + // total: batch_B * batch_C * R * B_last * (M + batch_A * N) + PrimExpr ops_with_rhs_first = (size_M + batch_A * size_N) * size_R 
* size_B * batch_B * batch_C; arith::Analyzer analyzer; analyzer.rewrite_simplify.SetEnabledExtensions(static_cast( @@ -214,8 +260,7 @@ std::tuple)>> return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); } - // If we cannot determine which order is best, keep the existing - // order. + // If we cannot determine which order is best, keep the existing order. return expr; }; diff --git a/tests/python/relax/test_transform_adjust_matmul_order.py b/tests/python/relax/test_transform_adjust_matmul_order.py index a086f3abdb8d..6adf1184581b 100644 --- a/tests/python/relax/test_transform_adjust_matmul_order.py +++ b/tests/python/relax/test_transform_adjust_matmul_order.py @@ -17,8 +17,11 @@ import inspect +import numpy as np import pytest +import torch +import tvm import tvm.testing from tvm import relax from tvm.script import ir as I @@ -234,8 +237,26 @@ class TestIdempotentRHSDynamic(Base): Expected = TestRHSDynamic.Expected -class TestLHSDynamicWithBatch(Base): - """Prefer (x*A)*B instead of x*(A*B)""" +class TestDynamicWithBatchSymbolic1(Base): + """Keep existing order when batch_size and lora_r are both symbolic. + + Before computes `x @ (A @ B)` with + `x: [batch_size, 1, 16]`, `A: [16, lora_r]`, `B: [lora_r, 32]`. + + RHS first (fuse A@B once, no batch on inner matmul): + 16*lora_r*32 + batch_size*1*16*32 = 512*(lora_r + batch_size) + + LHS first (both matmuls scale with batch_size): + batch_size*1*16*lora_r + batch_size*1*lora_r*32 = 48*batch_size*lora_r + + When `batch_size` and `lora_r` are known at compile-time: + - satisfy the inequality 48*batch_size*lora_r < 512*(lora_r + batch_size), + the LHS first is preferred. + - satisfy the inequality 512*(lora_r + batch_size) < 48*batch_size*lora_r, + the RHS first is preferred. + + Without bounds on `batch_size` and `lora_r`, neither side is provably cheaper. 
+ """ @I.ir_module class Before: @@ -250,6 +271,31 @@ def main( out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight) return out + Expected = Before + + +class TestDynamicWithBatchConcrete1LHSFirst(Base): + """With concrete shapes, LHS first is provably cheaper. + + batch_size=4, lora_r=16: + LHS first: 48*4*16 = 3072 + RHS first: 512*(16 + 4) = 10240 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 1, 16]), + A: R.Tensor([16, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor(["batch_size", 1, 32]): + batch_size = T.int64(4) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([16, 32]) = R.matmul(A, B) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight) + return out + @I.ir_module class Expected: @R.function @@ -258,15 +304,70 @@ def main( A: R.Tensor([16, "lora_r"]), B: R.Tensor(["lora_r", 32]), ) -> R.Tensor(["batch_size", 1, 32]): - lora_r = T.int64() - batch_size = T.int64() - x: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A) - x: R.Tensor([batch_size, 1, 32]) = R.matmul(x, B) - return x + batch_size = T.int64(4) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(weight, B) + return out -class TestRHSDynamicWithBatch(Base): - """Prefer A*(B*x) instead of (A*B)*x""" +class TestDynamicWithBatchConcrete1RHSFirst(Base): + """With concrete shapes, RHS first is provably cheaper. 
+ + batch_size=64, lora_r=16: + LHS first: 48*64*16 = 49152 + RHS first: 512*(16 + 64) = 40960 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 1, 16]), + A: R.Tensor([16, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor(["batch_size", 1, 32]): + batch_size = T.int64(64) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(weight, B) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["batch_size", 1, 16]), + A: R.Tensor([16, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor(["batch_size", 1, 32]): + batch_size = T.int64(64) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([16, 32]) = R.matmul(A, B) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight) + return out + + +class TestDynamicWithBatchSymbolic2(Base): + """Keep existing order when batch_size and lora_r are both symbolic. + + Before computes `(A @ B) @ x` with + `A: [32, lora_r]`, `B: [lora_r, 16]`, `x: [batch_size, 16, 1]`. + + LHS first (fuse A@B once, no batch on inner matmul): + 32*lora_r*16 + batch_size*32*16*1 = 512*(lora_r + batch_size) + + RHS first (both matmuls scale with batch_size): + batch_size*lora_r*16*1 + batch_size*32*lora_r*1 = 48*batch_size*lora_r + + When `batch_size` and `lora_r` are known at compile-time: + - satisfy the inequality 48*batch_size*lora_r < 512*(lora_r + batch_size), + the RHS first is preferred. + - satisfy the inequality 512*(lora_r + batch_size) < 48*batch_size*lora_r, + the LHS first is preferred. + + Without bounds on `batch_size` and `lora_r`, neither side is provably cheaper. + """ @I.ir_module class Before: @@ -281,6 +382,31 @@ def main( out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x) return out + Expected = Before + + +class TestDynamicWithBatchConcrete2RHSFirst(Base): + """With concrete shapes, RHS first is provably cheaper. 
+ + batch_size=4, lora_r=16: + RHS first: 48*4*16 = 3072 + LHS first: 512*(16 + 4) = 10240 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 16, 1]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor(["batch_size", 32, 1]): + batch_size = T.int64(4) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([32, 16]) = R.matmul(A, B) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x) + return out + @I.ir_module class Expected: @R.function @@ -289,11 +415,48 @@ def main( A: R.Tensor([32, "lora_r"]), B: R.Tensor(["lora_r", 16]), ) -> R.Tensor(["batch_size", 32, 1]): - lora_r = T.int64() - batch_size = T.int64() - x: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x) - x: R.Tensor([batch_size, 32, 1]) = R.matmul(A, x) - return x + batch_size = T.int64(4) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(A, weight) + return out + + +class TestDynamicWithBatchConcrete2LHSFirst(Base): + """With concrete shapes, LHS first is provably cheaper. 
+ + batch_size=64, lora_r=16: + RHS first: 48*64*16 = 49152 + LHS first: 512*(16 + 64) = 40960 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 16, 1]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor(["batch_size", 32, 1]): + batch_size = T.int64(64) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(A, weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["batch_size", 16, 1]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor(["batch_size", 32, 1]): + batch_size = T.int64(64) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([32, 16]) = R.matmul(A, B) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x) + return out class TestNoOpForFullyDynamicOnLHS(Base): @@ -513,5 +676,73 @@ def main( return x +class TestAdjustMatmulOrderAttentionBlock: + """AdjustMatmulOrder preserves numerics on a batched attention block. + + Covers ND `permute_dims` (swap last two axes) inside `matmul(q, kt)`, + regression for issue #19576. 
+ """ + + def _build_attention_module(self, batch, seq, dim): + """Minimal batched attention block exercising ND permute_dims + matmul.""" + bb = relax.BlockBuilder() + x = relax.Var("x", relax.TensorStructInfo((batch, seq, dim), "float32")) + wq = relax.Var("wq", relax.TensorStructInfo((dim, dim), "float32")) + wk = relax.Var("wk", relax.TensorStructInfo((dim, dim), "float32")) + wv = relax.Var("wv", relax.TensorStructInfo((dim, dim), "float32")) + wo = relax.Var("wo", relax.TensorStructInfo((dim, dim), "float32")) + with bb.function("main", [x, wq, wk, wv, wo]): + with bb.dataflow(): + q = bb.emit(relax.op.matmul(x, wq)) + k = bb.emit(relax.op.matmul(x, wk)) + v = bb.emit(relax.op.matmul(x, wv)) + kt = bb.emit(relax.op.permute_dims(k, axes=[0, 2, 1])) + scores = bb.emit(relax.op.matmul(q, kt)) + scale = bb.emit(relax.const(1.0 / np.sqrt(dim), "float32")) + scores = bb.emit(relax.op.multiply(scores, scale)) + attn = bb.emit(relax.op.nn.softmax(scores, axis=-1)) + out = bb.emit(relax.op.matmul(attn, v)) + proj = bb.emit_output(relax.op.matmul(out, wo)) + bb.emit_func_output(proj) + return bb.finalize() + + def _run_relax_main(self, mod, inputs): + exe = relax.build(mod, target="llvm") + vm = relax.VirtualMachine(exe, device=tvm.cpu()) + args = [tvm.runtime.tensor(arr, device=tvm.cpu()) for arr in inputs] + return vm["main"](*args).numpy() + + def _torch_attention_ref(self, x_np, w_np, dim): + x = torch.from_numpy(x_np) + w = torch.from_numpy(w_np) + with torch.no_grad(): + q = torch.matmul(x, w) + k = torch.matmul(x, w) + v = torch.matmul(x, w) + scores = torch.matmul(q, k.transpose(-2, -1)) + scores = scores * (1.0 / np.sqrt(dim)) + attn = torch.nn.functional.softmax(scores, dim=-1) + out = torch.matmul(attn, v) + out = torch.matmul(out, w) + return out.detach().numpy() + + @pytest.mark.parametrize("batch,seq,dim", [(2, 16, 64)]) + def test_attention_block_numerics(self, batch, seq, dim): + mod = self._build_attention_module(batch, seq, dim) + mod_opt = 
relax.transform.AdjustMatmulOrder()(mod) + + x_np = np.random.randn(batch, seq, dim).astype("float32") + w_np = np.random.randn(dim, dim).astype("float32") + inputs = [x_np, w_np, w_np, w_np, w_np] + + ref = self._torch_attention_ref(x_np, w_np, dim) + out_before = self._run_relax_main(mod, inputs) + out_after = self._run_relax_main(mod_opt, inputs) + + tvm.testing.assert_allclose(out_before, ref, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(out_after, ref, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(out_before, out_after, rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": tvm.testing.main() From 585d45c8fff8702f6f695396269f836329c8f089 Mon Sep 17 00:00:00 2001 From: ConvolutedDog Date: Tue, 2 Jun 2026 22:37:45 +0800 Subject: [PATCH 2/2] [Fix][Relax] Fix batched FLOP model in AdjustMatmulOrder Use broadcasted batch prefixes per matmul (via InferBinaryBroadcastShape) instead of multiplying independent leading-axis products, which could prefer the wrong order for broadcasted chains. Split broadcast shape inference into an analyzer-only API that returns BinaryBroadcastShapeInferResult; keep the Call/BlockBuilder wrapper for diagnostics. 
--- src/relax/op/op_common.cc | 48 +++-- src/relax/op/op_common.h | 30 ++++ src/relax/transform/adjust_matmul_order.cc | 79 ++++++--- .../test_transform_adjust_matmul_order.py | 165 +++++++++++++++--- 4 files changed, 266 insertions(+), 56 deletions(-) diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index 61485b09112b..a019b87f3a2b 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -22,6 +22,7 @@ #include #include +#include namespace tvm { namespace relax { @@ -108,10 +109,10 @@ ffi::Array GetTensorStructInfoFromTuple(const Call& call, cons return tensor_sinfo; } -ffi::Optional> InferBinaryBroadcastShape( - const Call& call, const BlockBuilder& ctx, const ffi::Array& x1_shape, - const ffi::Array& x2_shape) { - arith::Analyzer* analyzer = ctx->GetAnalyzer(); +BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer* analyzer, + const ffi::Array& x1_shape, + const ffi::Array& x2_shape) { + BinaryBroadcastShapeInferResult result; int x1_ndim = x1_shape.size(); int x2_ndim = x2_shape.size(); int max_ndim = std::max(x1_ndim, x2_ndim); @@ -132,20 +133,45 @@ ffi::Optional> InferBinaryBroadcastShape( } else if (analyzer->CanProveEqual(dim0, dim1)) { output_shape.push_back(dim0); } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) { - ctx->ReportFatal(Diagnostic::Error(call) - << "In " << call->op << ", the first input shape at dim " << x1_ndim - i - << " is " << dim0 << " and the second input shape at dim " << x2_ndim - i - << " is " << dim1 << ", which are not broadcastable."); + result.status = BinaryBroadcastShapeInferResult::Status::kConflict; + result.message = [&]() { + std::ostringstream os; + os << "the first input shape at dim " << x1_ndim - i << " is " << dim0 + << " and the second input shape at dim " << x2_ndim - i << " is " << dim1 + << ", which are not broadcastable."; + return ffi::String(os.str()); + }(); + return result; } else { - // Use simple fallback when shape mismatch. 
- return std::nullopt; + result.status = BinaryBroadcastShapeInferResult::Status::kUnknown; + return result; } } auto& longer_shape = (x1_ndim > x2_ndim) ? x1_shape : x2_shape; for (; i <= max_ndim; ++i) { output_shape.push_back(longer_shape[max_ndim - i]); } - return ffi::Array(output_shape.rbegin(), output_shape.rend()); + result.status = BinaryBroadcastShapeInferResult::Status::kSuccess; + result.shape = ffi::Array(output_shape.rbegin(), output_shape.rend()); + return result; +} + +ffi::Optional> InferBinaryBroadcastShape( + const Call& call, const BlockBuilder& ctx, const ffi::Array& x1_shape, + const ffi::Array& x2_shape) { + auto infer_result = InferBinaryBroadcastShape(ctx->GetAnalyzer(), x1_shape, x2_shape); + if (infer_result.status == BinaryBroadcastShapeInferResult::Status::kConflict) { + TVM_FFI_ICHECK(infer_result.message.has_value()); + ctx->ReportFatal(Diagnostic::Error(call) + << "In " << call->op << ", " << infer_result.message.value()); + } else if (infer_result.status == BinaryBroadcastShapeInferResult::Status::kSuccess) { + TVM_FFI_ICHECK(infer_result.shape.has_value()); + return infer_result.shape.value(); + } else { + // Unknown status, use simple fallback when shape mismatch. + return std::nullopt; + } + TVM_FFI_UNREACHABLE(); } std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 774eccfd58dd..6f7de974cbe6 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -387,6 +387,36 @@ inline ffi::Optional InferBinaryArithOpOutVDevice(const Call& call, return lhs_vdevice; } +/*! \brief Result of binary broadcast shape inference without diagnostic context. */ +struct BinaryBroadcastShapeInferResult { + enum class Status { + /*! \brief Broadcast output shape is known. */ + kSuccess, + /*! \brief Shapes may be broadcastable but cannot be proved symbolically. */ + kUnknown, + /*! \brief Concrete shapes are not broadcastable. 
*/ + kConflict, + }; + + /*! \brief Inference status. */ + Status status = Status::kUnknown; + /*! \brief Broadcasted shape if status is kSuccess. */ + ffi::Optional> shape; + /*! \brief Human-readable conflict description if status is kConflict. */ + ffi::Optional message; +}; + +/*! + * \brief Infer the output shape for binary broadcast operators. + * \param analyzer The arithmetic analyzer used to prove shape equality. + * \param x1_shape The shape of the first operand. + * \param x2_shape The shape of the second operand. + * \return Inference status and broadcasted shape, or a conflict message. + */ +BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer* analyzer, + const ffi::Array& x1_shape, + const ffi::Array& x2_shape); + /*! * \brief Infer the output shape for binary broadcast operators. * \param call The context Call to the operator. diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index efdacb4e0e83..012c8ce5b71a 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -34,6 +34,7 @@ #include #include +#include "../op/op_common.h" #include "../op/tensor/linear_algebra.h" #include "../op/tensor/manipulate.h" @@ -41,6 +42,27 @@ namespace tvm { namespace relax { namespace { + +ffi::Array GetBatchPrefix(const ffi::Array& shape) { + if (shape.size() <= 2) return {}; + return {shape.begin(), shape.end() - 2}; +} + +PrimExpr ProductDims(const ffi::Array& dims) { + PrimExpr product = IntImm(DataType::Int(64), 1); + for (const auto& dim : dims) product = product * dim; + return product; +} + +ffi::Optional> InferBatchedMatmulBroadcastPrefix( + arith::Analyzer* analyzer, const ffi::Array& x1, const ffi::Array& x2) { + auto infer_result = InferBinaryBroadcastShape(analyzer, x1, x2); + if (infer_result.status == BinaryBroadcastShapeInferResult::Status::kSuccess) { + return infer_result.shape; + } + return std::nullopt; +} + std::tuple)>> 
CreatePatterns( const Function& func) { auto compile_time_arr = ComputableAtCompileTime(func); @@ -170,11 +192,13 @@ std::tuple)>> }; if (matches.count(pat_permuted_matmul_on_lhs)) { + if (shape_a.size() < 2 || shape_b.size() < 2) return expr; expr_a = permute_last_two_dims(expr_a); expr_b = permute_last_two_dims(expr_b); transpose_shape_last_two_dims(shape_a); transpose_shape_last_two_dims(shape_b); } else if (matches.count(pat_permuted_matmul_on_rhs)) { + if (shape_b.size() < 2 || shape_c.size() < 2) return expr; expr_b = permute_last_two_dims(expr_b); expr_c = permute_last_two_dims(expr_c); transpose_shape_last_two_dims(shape_b); @@ -215,44 +239,49 @@ std::tuple)>> PrimExpr size_M = shape_c[shape_c.size() - 2]; // row of C and col of B PrimExpr size_B = shape_c[shape_c.size() - 1]; // col of C - auto calculate_batch = [](ffi::Array& shape) { - PrimExpr batch = 1; - for (size_t i = 0; i < shape.size() - 2; ++i) { - batch *= shape[i]; - } - return batch; - }; - - PrimExpr batch_A = calculate_batch(shape_a); - PrimExpr batch_B = calculate_batch(shape_b); - PrimExpr batch_C = calculate_batch(shape_c); + arith::Analyzer analyzer; + auto prefix_a = GetBatchPrefix(shape_a); + auto prefix_b = GetBatchPrefix(shape_b); + auto prefix_c = GetBatchPrefix(shape_c); + + auto opt_prefix_ab = InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_a, prefix_b); + if (!opt_prefix_ab) return expr; + auto opt_prefix_bc = InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_b, prefix_c); + if (!opt_prefix_bc) return expr; + auto opt_prefix_outer_lhs = + InferBatchedMatmulBroadcastPrefix(&analyzer, opt_prefix_ab.value(), prefix_c); + if (!opt_prefix_outer_lhs) return expr; + auto opt_prefix_outer_rhs = + InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_a, opt_prefix_bc.value()); + if (!opt_prefix_outer_rhs) return expr; + + PrimExpr batch_ab = ProductDims(opt_prefix_ab.value()); + PrimExpr batch_bc = ProductDims(opt_prefix_bc.value()); + PrimExpr batch_outer_lhs = 
ProductDims(opt_prefix_outer_lhs.value()); + PrimExpr batch_outer_rhs = ProductDims(opt_prefix_outer_rhs.value()); // Compare naive matmul FLOPs for two evaluation orders of // matmul(A, matmul(B, C)) vs matmul(matmul(A, B), C) // - // Matrix dims (last two axes of each operand): - // A: [N, R] B: [R, M] C: [M, B_last] - // Batch prefixes (product of all leading axes): - // batch_A, batch_B, batch_C + // Matrix dims (last two axes): A [N, R], B [R, M], C [M, B_last] + // Each matmul uses the broadcasted batch prefix of its operands. // // LHS first — matmul(matmul(A, B), C): - // inner matmul(A, B): batch_A * batch_B * N * R * M - // outer matmul(., C): batch_A * batch_B * batch_C * N * M * B_last - // total: batch_A * batch_B * N * M * (R + batch_C * B_last) - PrimExpr ops_with_lhs_first = (size_R + batch_C * size_B) * size_N * size_M * batch_A * batch_B; + // batch_ab * N * R * M + batch_outer_lhs * N * M * B_last + PrimExpr ops_with_lhs_first = + batch_ab * size_N * size_R * size_M + batch_outer_lhs * size_N * size_M * size_B; // RHS first — matmul(A, matmul(B, C)): - // inner matmul(B, C): batch_B * batch_C * R * M * B_last - // outer matmul(A, .): batch_A * batch_B * batch_C * N * R * B_last - // total: batch_B * batch_C * R * B_last * (M + batch_A * N) - PrimExpr ops_with_rhs_first = (size_M + batch_A * size_N) * size_R * size_B * batch_B * batch_C; + // batch_bc * R * M * B_last + batch_outer_rhs * N * R * B_last + PrimExpr ops_with_rhs_first = + batch_bc * size_R * size_M * size_B + batch_outer_rhs * size_N * size_R * size_B; - arith::Analyzer analyzer; analyzer.rewrite_simplify.SetEnabledExtensions(static_cast( analyzer.rewrite_simplify.GetEnabledExtensions() | arith::RewriteSimplifier::Extension::kComparisonOfProductAndSum)); With func_attr_constraint(&analyzer, symbolic_var_constraints); With analyzer_constraint( - &analyzer, size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0); + &analyzer, batch_ab > 0 && batch_bc > 0 && batch_outer_lhs > 0 && 
batch_outer_rhs > 0 && + size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0); if (analyzer.CanProve(ops_with_lhs_first < ops_with_rhs_first)) { return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); diff --git a/tests/python/relax/test_transform_adjust_matmul_order.py b/tests/python/relax/test_transform_adjust_matmul_order.py index 6adf1184581b..9600c97bdaac 100644 --- a/tests/python/relax/test_transform_adjust_matmul_order.py +++ b/tests/python/relax/test_transform_adjust_matmul_order.py @@ -42,7 +42,13 @@ def test_compare(self): class TestLHS(Base): - """Prefer (x*A)*B instead of x*(A*B)""" + """Prefer (x*A)*B instead of x*(A*B) + + LHS first - (x*A)*B: + ops = 1*16*2 + 1*2*32 = 96 + RHS first - x*(A*B): + ops = 16*2*32 + 1*16*32 = 1536 + """ @I.ir_module class Before: @@ -70,7 +76,13 @@ def main( class TestRHS(Base): - """Prefer A*(B*x) instead of (A*B)*x""" + """Prefer A*(B*x) instead of (A*B)*x + + LHS first - (A*B)*x: + ops = 32*2*16 + 32*16*1 = 1536 + RHS first - A*(B*x): + ops = 2*16*1 + 32*2*1 = 96 + """ @I.ir_module class Before: @@ -166,6 +178,13 @@ class TestLHSDynamic(Base): This case appears when evaluating LoRA-tuned models with a dynamic rank. + + LHS first - (x*A)*B: + ops = 1*16*lora_r + 1*lora_r*32 = 48*lora_r + RHS first - x*(A*B): + ops = 16*lora_r*32 + 1*16*32 = 512*lora_r + 512 + + 48*lora_r can be proved to be less than 512*lora_r + 512, so the LHS first is preferred. """ @I.ir_module @@ -195,7 +214,15 @@ def main( class TestRHSDynamic(Base): - """Prefer A*(B*x) instead of (A*B)*x""" + """Prefer A*(B*x) instead of (A*B)*x + + LHS first - (A*B)*x: + ops = 32*lora_r*16 + 32*16*1 = 512*lora_r + 512 + RHS first - A*(B*x): + ops = lora_r*16*1 + 32*lora_r*1 = 48*lora_r + + 48*lora_r can be proved to be less than 512*lora_r + 512, so the RHS first is preferred. 
+ """ @I.ir_module class Before: @@ -238,15 +265,16 @@ class TestIdempotentRHSDynamic(Base): class TestDynamicWithBatchSymbolic1(Base): - """Keep existing order when batch_size and lora_r are both symbolic. + """When both batch_size and lora_r are symbolic and it cannot be proven which + is cheaper, LHS or RHS, maintain the existing order. - Before computes `x @ (A @ B)` with + `Before` computes `x * (A * B)` with `x: [batch_size, 1, 16]`, `A: [16, lora_r]`, `B: [lora_r, 32]`. - RHS first (fuse A@B once, no batch on inner matmul): + RHS first - x * (A * B): 16*lora_r*32 + batch_size*1*16*32 = 512*(lora_r + batch_size) - LHS first (both matmuls scale with batch_size): + LHS first - (x * A) * B: batch_size*1*16*lora_r + batch_size*1*lora_r*32 = 48*batch_size*lora_r When `batch_size` and `lora_r` are known at compile-time: @@ -349,15 +377,16 @@ def main( class TestDynamicWithBatchSymbolic2(Base): - """Keep existing order when batch_size and lora_r are both symbolic. + """When both batch_size and lora_r are symbolic and it cannot be proven which + is cheaper, LHS or RHS, maintain the existing order. - Before computes `(A @ B) @ x` with + `Before` computes `(A * B) * x` with `A: [32, lora_r]`, `B: [lora_r, 16]`, `x: [batch_size, 16, 1]`. - LHS first (fuse A@B once, no batch on inner matmul): + LHS first - (A * B) * x: 32*lora_r*16 + batch_size*32*16*1 = 512*(lora_r + batch_size) - RHS first (both matmuls scale with batch_size): + RHS first - A * (B * x): batch_size*lora_r*16*1 + batch_size*32*lora_r*1 = 48*batch_size*lora_r When `batch_size` and `lora_r` are known at compile-time: @@ -516,6 +545,11 @@ class TestRHSPermuteDims(Base): """Prefer (x*A)*B instead of x*(A*B) Like `TestRHS`, but the weights on the RHS are transposed. 
+ + Before: x * (BT * AT) + ops = 16*2*32 + 1*16*32 = 1536 + After: (x * BT) * AT + ops = 1*16*2 + 1*2*32 = 96 """ @I.ir_module @@ -551,6 +585,13 @@ class TestRHSPermuteDimsDynamic(Base): Like `TestRHSPermuteDims`, but the weights on the RHS have a dynamic shape. + + Before: x * (BT * AT) + ops = 16*lora_r*32 + 1*16*32 = 512*lora_r + 512 + After: (x * BT) * AT + ops = 1*16*lora_r + 1*lora_r*32 = 48*lora_r + + 48*lora_r can be proved to be less than 512*lora_r + 512, so the `After` order is preferred. """ @I.ir_module @@ -596,15 +637,15 @@ class TestRHSPermuteDimsWithDynamicBatch(Base): ops_left_to_right = (batch_size + lora_r)*4096*4096 ops_right_to_left = (4096 + 4096)*batch_size*lora_r - Without an upper bound on `lora_r`, we cannot prove which of these - is the preferred execution order. With the upper bound, TVM can - determine the preferred order using the following arithmethic - reasoning. + Without an upper bound on `batch_size` and `lora_r`, we cannot prove which + of these is the preferred execution order. - (batch_size + lora_r)*4096*4096 < (4096 + 4096)*batch_size*lora_r - (batch_size + lora_r)*2048 < batch_size*lora_r - 1/batch_size + 1/lora_r < 1/2048 + With the upper bound, TVM can determine the preferred order using + the following arithmetic reasoning.
+ (batch_size + lora_r)*4096*4096 > (4096 + 4096)*batch_size*lora_r + (batch_size + lora_r)*2048 > batch_size*lora_r + 1/batch_size + 1/lora_r > 1/2048 """ @I.ir_module @@ -615,7 +656,12 @@ def main( A: R.Tensor([4096, "lora_r"]), B: R.Tensor(["lora_r", 4096]), ) -> R.Tensor(["batch_size", 4096]): - R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}}) + R.func_attr( + { + "tir_var_upper_bound": {"lora_r": 2048, "batch_size": 2048}, + } + ) + lora_r = T.int64() # noqa: F841 batch_size = T.int64() linear_weight: R.Tensor([4096, 4096]) = R.matmul(A, B) matmul_weight: R.Tensor([4096, 4096]) = R.permute_dims(linear_weight) @@ -630,7 +676,11 @@ def main( A: R.Tensor([4096, "lora_r"]), B: R.Tensor(["lora_r", 4096]), ) -> R.Tensor(["batch_size", 4096]): - R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}}) + R.func_attr( + { + "tir_var_upper_bound": {"lora_r": 2048, "batch_size": 2048}, + } + ) lora_r = T.int64() batch_size = T.int64() B_transpose = R.permute_dims(B) @@ -645,6 +695,11 @@ class TestRHSPermuteDimsDynamicWithSquareMatrix(Base): Like `TestRHSPermuteDims`, but the weights on the RHS have a dynamic shape. + + Before: x * (BT * AT) + ops = 32*lora_r*32 + 1*32*32 = 1024*lora_r + 1024 + After: (x * BT) * AT + ops = 1*32*lora_r + 1*lora_r*32 = 64*lora_r """ @I.ir_module @@ -676,6 +731,76 @@ def main( return x +class TestBatchedBroadcastPreferLHSFirst(Base): + """Use broadcasted batch prefix per matmul, not independent prefix products. + + Example with broadcast batch axes: A:[2,1,1], B:[2,1,2], C:[2,2,3]. 
+ + LHS first: (A * B) * C + ops = 2*1*1*2 + 2*1*2*3 = 16 + RHS first: A * (B * C) + ops = 2*1*2*3 + 2*1*1*3 = 18 + """ + + @I.ir_module + class Before: + @R.function + def main( + A: R.Tensor([2, 1, 1]), + B: R.Tensor([2, 1, 2]), + C: R.Tensor([2, 2, 3]), + ) -> R.Tensor([2, 1, 3]): + out: R.Tensor([2, 1, 3]) = R.matmul(A, R.matmul(B, C)) + return out + + @I.ir_module + class Expected: + @R.function + def main( + A: R.Tensor([2, 1, 1]), + B: R.Tensor([2, 1, 2]), + C: R.Tensor([2, 2, 3]), + ) -> R.Tensor([2, 1, 3]): + temp: R.Tensor([2, 1, 2]) = R.matmul(A, B) + out: R.Tensor([2, 1, 3]) = R.matmul(temp, C) + return out + + +class TestBatchedSharedPrefixPreferLHSFirst(Base): + """All operands share a nontrivial batch prefix [2, 3]. + + Shapes: A:[2,3,4,5], B:[2,3,5,6], C:[2,3,6,7] + + LHS first: + ops = 6*4*5*6 + 6*4*6*7 = 1728 + RHS first: + ops = 6*5*6*7 + 6*4*5*7 = 2100 + """ + + @I.ir_module + class Before: + @R.function + def main( + A: R.Tensor([2, 3, 4, 5]), + B: R.Tensor([2, 3, 5, 6]), + C: R.Tensor([2, 3, 6, 7]), + ) -> R.Tensor([2, 3, 4, 7]): + out: R.Tensor([2, 3, 4, 7]) = R.matmul(A, R.matmul(B, C)) + return out + + @I.ir_module + class Expected: + @R.function + def main( + A: R.Tensor([2, 3, 4, 5]), + B: R.Tensor([2, 3, 5, 6]), + C: R.Tensor([2, 3, 6, 7]), + ) -> R.Tensor([2, 3, 4, 7]): + temp: R.Tensor([2, 3, 4, 6]) = R.matmul(A, B) + out: R.Tensor([2, 3, 4, 7]) = R.matmul(temp, C) + return out + + class TestAdjustMatmulOrderAttentionBlock: """AdjustMatmulOrder preserves numerics on a batched attention block.