Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 37 additions & 11 deletions src/relax/op/op_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <tvm/ffi/cast.h>

#include <algorithm>
#include <sstream>

namespace tvm {
namespace relax {
Expand Down Expand Up @@ -108,10 +109,10 @@ ffi::Array<TensorStructInfo> GetTensorStructInfoFromTuple(const Call& call, cons
return tensor_sinfo;
}

ffi::Optional<ffi::Array<PrimExpr>> InferBinaryBroadcastShape(
const Call& call, const BlockBuilder& ctx, const ffi::Array<PrimExpr>& x1_shape,
const ffi::Array<PrimExpr>& x2_shape) {
arith::Analyzer* analyzer = ctx->GetAnalyzer();
BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer* analyzer,
const ffi::Array<PrimExpr>& x1_shape,
const ffi::Array<PrimExpr>& x2_shape) {
BinaryBroadcastShapeInferResult result;
int x1_ndim = x1_shape.size();
int x2_ndim = x2_shape.size();
int max_ndim = std::max(x1_ndim, x2_ndim);
Expand All @@ -132,20 +133,45 @@ ffi::Optional<ffi::Array<PrimExpr>> InferBinaryBroadcastShape(
} else if (analyzer->CanProveEqual(dim0, dim1)) {
output_shape.push_back(dim0);
} else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "In " << call->op << ", the first input shape at dim " << x1_ndim - i
<< " is " << dim0 << " and the second input shape at dim " << x2_ndim - i
<< " is " << dim1 << ", which are not broadcastable.");
result.status = BinaryBroadcastShapeInferResult::Status::kConflict;
result.message = [&]() {
std::ostringstream os;
os << "the first input shape at dim " << x1_ndim - i << " is " << dim0
<< " and the second input shape at dim " << x2_ndim - i << " is " << dim1
<< ", which are not broadcastable.";
return ffi::String(os.str());
}();
return result;
} else {
// Use simple fallback when shape mismatch.
return std::nullopt;
result.status = BinaryBroadcastShapeInferResult::Status::kUnknown;
return result;
}
}
auto& longer_shape = (x1_ndim > x2_ndim) ? x1_shape : x2_shape;
for (; i <= max_ndim; ++i) {
output_shape.push_back(longer_shape[max_ndim - i]);
}
return ffi::Array<PrimExpr>(output_shape.rbegin(), output_shape.rend());
result.status = BinaryBroadcastShapeInferResult::Status::kSuccess;
result.shape = ffi::Array<PrimExpr>(output_shape.rbegin(), output_shape.rend());
return result;
}

ffi::Optional<ffi::Array<PrimExpr>> InferBinaryBroadcastShape(
const Call& call, const BlockBuilder& ctx, const ffi::Array<PrimExpr>& x1_shape,
const ffi::Array<PrimExpr>& x2_shape) {
auto infer_result = InferBinaryBroadcastShape(ctx->GetAnalyzer(), x1_shape, x2_shape);
if (infer_result.status == BinaryBroadcastShapeInferResult::Status::kConflict) {
TVM_FFI_ICHECK(infer_result.message.has_value());
ctx->ReportFatal(Diagnostic::Error(call)
<< "In " << call->op << ", " << infer_result.message.value());
} else if (infer_result.status == BinaryBroadcastShapeInferResult::Status::kSuccess) {
TVM_FFI_ICHECK(infer_result.shape.has_value());
return infer_result.shape.value();
} else {
// Unknown status, use simple fallback when shape mismatch.
return std::nullopt;
}
TVM_FFI_UNREACHABLE();
}

std::vector<int> NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim,
Expand Down
30 changes: 30 additions & 0 deletions src/relax/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,36 @@ inline ffi::Optional<VDevice> InferBinaryArithOpOutVDevice(const Call& call,
return lhs_vdevice;
}

/*! \brief Result of binary broadcast shape inference without diagnostic context. */
struct BinaryBroadcastShapeInferResult {
enum class Status {
/*! \brief Broadcast output shape is known. */
kSuccess,
/*! \brief Shapes may be broadcastable but cannot be proved symbolically. */
kUnknown,
/*! \brief Concrete shapes are not broadcastable. */
kConflict,
};

/*! \brief Inference status. */
Status status = Status::kUnknown;
/*! \brief Broadcasted shape if status is kSuccess. */
ffi::Optional<ffi::Array<PrimExpr>> shape;
/*! \brief Human-readable conflict description if status is kConflict. */
ffi::Optional<ffi::String> message;
};

/*!
* \brief Infer the output shape for binary broadcast operators.
* \param analyzer The arithmetic analyzer used to prove shape equality.
* \param x1_shape The shape of the first operand.
* \param x2_shape The shape of the second operand.
* \return Inference status and broadcasted shape, or a conflict message.
*/
BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer* analyzer,
const ffi::Array<PrimExpr>& x1_shape,
const ffi::Array<PrimExpr>& x2_shape);

/*!
* \brief Infer the output shape for binary broadcast operators.
* \param call The context Call to the operator.
Expand Down
132 changes: 103 additions & 29 deletions src/relax/transform/adjust_matmul_order.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,35 @@
#include <unordered_set>
#include <vector>

#include "../op/op_common.h"
#include "../op/tensor/linear_algebra.h"
#include "../op/tensor/manipulate.h"

namespace tvm {
namespace relax {

namespace {

ffi::Array<PrimExpr> GetBatchPrefix(const ffi::Array<PrimExpr>& shape) {
if (shape.size() <= 2) return {};
return {shape.begin(), shape.end() - 2};
}

PrimExpr ProductDims(const ffi::Array<PrimExpr>& dims) {
PrimExpr product = IntImm(DataType::Int(64), 1);
for (const auto& dim : dims) product = product * dim;
return product;
}

ffi::Optional<ffi::Array<PrimExpr>> InferBatchedMatmulBroadcastPrefix(
arith::Analyzer* analyzer, const ffi::Array<PrimExpr>& x1, const ffi::Array<PrimExpr>& x2) {
auto infer_result = InferBinaryBroadcastShape(analyzer, x1, x2);
if (infer_result.status == BinaryBroadcastShapeInferResult::Status::kSuccess) {
return infer_result.shape;
}
return std::nullopt;
}

std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>> CreatePatterns(
const Function& func) {
auto compile_time_arr = ComputableAtCompileTime(func);
Expand Down Expand Up @@ -141,20 +163,46 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>>
auto shape_b = opt_shape_b.value();
auto shape_c = opt_shape_c.value();

auto permute_last_two_dims = [&](Expr expr) -> Expr {
auto opt_shape = get_shape(expr);
if (!opt_shape) return expr;

size_t ndim = opt_shape.value().size();
TVM_FFI_ICHECK_GE(ndim, 2);

ffi::Optional<ffi::Array<int64_t>> axes;

if (ndim == 2) {
// Pass none axes to permute_dims for simple transpose of 2D tensors.
axes = std::nullopt;
} else {
ffi::Array<int64_t> axes_array;
for (size_t i = 0; i < ndim; ++i) axes_array.push_back(i);
axes_array.Set(ndim - 1, ndim - 2);
axes_array.Set(ndim - 2, ndim - 1);
axes = ffi::Optional<ffi::Array<int64_t>>(axes_array);
}
return permute_dims(std::move(expr), axes);
};

auto transpose_shape_last_two_dims = [&](ffi::Array<PrimExpr>& shape) {
PrimExpr last_dim_shape = shape[shape.size() - 1];
shape.Set(shape.size() - 1, shape[shape.size() - 2]);
shape.Set(shape.size() - 2, last_dim_shape);
};

if (matches.count(pat_permuted_matmul_on_lhs)) {
expr_a = permute_dims(expr_a, std::nullopt);
expr_b = permute_dims(expr_b, std::nullopt);
TVM_FFI_ICHECK_EQ(shape_a.size(), 2);
TVM_FFI_ICHECK_EQ(shape_b.size(), 2);
shape_a = {shape_a[1], shape_a[0]};
shape_b = {shape_b[1], shape_b[0]};
if (shape_a.size() < 2 || shape_b.size() < 2) return expr;
expr_a = permute_last_two_dims(expr_a);
expr_b = permute_last_two_dims(expr_b);
transpose_shape_last_two_dims(shape_a);
transpose_shape_last_two_dims(shape_b);
} else if (matches.count(pat_permuted_matmul_on_rhs)) {
expr_b = permute_dims(expr_b, std::nullopt);
expr_c = permute_dims(expr_c, std::nullopt);
TVM_FFI_ICHECK_EQ(shape_b.size(), 2);
TVM_FFI_ICHECK_EQ(shape_c.size(), 2);
shape_b = {shape_b[1], shape_b[0]};
shape_c = {shape_c[1], shape_c[0]};
if (shape_b.size() < 2 || shape_c.size() < 2) return expr;
expr_b = permute_last_two_dims(expr_b);
expr_c = permute_last_two_dims(expr_c);
transpose_shape_last_two_dims(shape_b);
transpose_shape_last_two_dims(shape_c);
}
Comment thread
ConvolutedDog marked this conversation as resolved.

// If two of the three are compile-time, group those two values
Expand All @@ -166,13 +214,7 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>>
}

// Otherwise, select the order that reduces the total number of
// operations required, assuming a naive matmul.

// Matmul on LHS: ([N,R]*[R,M]) * [M,batch]
// Matmul on RHS: [N,R] * ([R,M]*[M,batch])
//
// LHS first: `N*R*M + N*M*batch = N*M*(R+batch)`
// RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch`
// operations required, assuming a naive matmul (see below).

if (shape_a.size() == 1) {
shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]};
Expand All @@ -192,30 +234,62 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>>
shape_c = {shape_c[0], IntImm(shape_c[0].dtype(), 1)};
}

auto size_N = shape_a[shape_a.size() - 2];
auto size_R = shape_a[shape_a.size() - 1];
auto size_M = shape_c[shape_c.size() - 2];
auto size_B = shape_c[shape_c.size() - 1];

auto ops_with_lhs_first = (size_R + size_B) * size_N * size_M;
auto ops_with_rhs_first = (size_M + size_N) * size_R * size_B;
PrimExpr size_N = shape_a[shape_a.size() - 2]; // row of A
PrimExpr size_R = shape_a[shape_a.size() - 1]; // col of A and row of B
PrimExpr size_M = shape_c[shape_c.size() - 2]; // row of C and col of B
PrimExpr size_B = shape_c[shape_c.size() - 1]; // col of C

arith::Analyzer analyzer;
auto prefix_a = GetBatchPrefix(shape_a);
auto prefix_b = GetBatchPrefix(shape_b);
auto prefix_c = GetBatchPrefix(shape_c);

auto opt_prefix_ab = InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_a, prefix_b);
if (!opt_prefix_ab) return expr;
auto opt_prefix_bc = InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_b, prefix_c);
if (!opt_prefix_bc) return expr;
auto opt_prefix_outer_lhs =
InferBatchedMatmulBroadcastPrefix(&analyzer, opt_prefix_ab.value(), prefix_c);
if (!opt_prefix_outer_lhs) return expr;
auto opt_prefix_outer_rhs =
InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_a, opt_prefix_bc.value());
if (!opt_prefix_outer_rhs) return expr;

PrimExpr batch_ab = ProductDims(opt_prefix_ab.value());
PrimExpr batch_bc = ProductDims(opt_prefix_bc.value());
PrimExpr batch_outer_lhs = ProductDims(opt_prefix_outer_lhs.value());
PrimExpr batch_outer_rhs = ProductDims(opt_prefix_outer_rhs.value());

// Compare naive matmul FLOPs for two evaluation orders of
// matmul(A, matmul(B, C)) vs matmul(matmul(A, B), C)
//
// Matrix dims (last two axes): A [N, R], B [R, M], C [M, B_last]
// Each matmul uses the broadcasted batch prefix of its operands.
//
// LHS first — matmul(matmul(A, B), C):
// batch_ab * N * R * M + batch_outer_lhs * N * M * B_last
PrimExpr ops_with_lhs_first =
batch_ab * size_N * size_R * size_M + batch_outer_lhs * size_N * size_M * size_B;
// RHS first — matmul(A, matmul(B, C)):
// batch_bc * R * M * B_last + batch_outer_rhs * N * R * B_last
PrimExpr ops_with_rhs_first =
batch_bc * size_R * size_M * size_B + batch_outer_rhs * size_N * size_R * size_B;

analyzer.rewrite_simplify.SetEnabledExtensions(static_cast<arith::RewriteSimplifier::Extension>(
analyzer.rewrite_simplify.GetEnabledExtensions() |
arith::RewriteSimplifier::Extension::kComparisonOfProductAndSum));
With<arith::ConstraintContext> func_attr_constraint(&analyzer, symbolic_var_constraints);
With<arith::ConstraintContext> analyzer_constraint(
&analyzer, size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0);
&analyzer, batch_ab > 0 && batch_bc > 0 && batch_outer_lhs > 0 && batch_outer_rhs > 0 &&
size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0);

if (analyzer.CanProve(ops_with_lhs_first < ops_with_rhs_first)) {
return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void());
} else if (analyzer.CanProve(ops_with_rhs_first < ops_with_lhs_first)) {
return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void());
}

// If we cannot determine which order is best, keep the existing
// order.
// If we cannot determine which order is best, keep the existing order.
return expr;
};

Expand Down
Loading
Loading