diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 36d8e76c2423..d7503b8f4f9c 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -83,9 +83,10 @@ struct CreateFuncInfo { } }; -BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::Tensor& tensor, - Array bindings, PrimExpr expr_body, - CreateFuncInfo* info, arith::Analyzer* analyzer) { +BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, + const Array& tensors, Array bindings, + PrimExpr expr_body, CreateFuncInfo* info, + arith::Analyzer* analyzer) { // Step 1. Push_back data_par axis and reduce_axis into block_vars. Array iter_vars; std::unordered_map var_map; @@ -105,16 +106,22 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te:: f_push_block_vars(compute_op->axis); f_push_block_vars(compute_op->reduce_axis); - // Step 2. Declare buffer and update op2buffers - Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); - info->tensor2buffers[tensor] = buffer; - - // Step 3. Add Buffer to root_alloc - if (!info->IsArg(tensor)) { - info->root_alloc.push_back(buffer); + // Step 2. + // - Declare buffers + // - Update `op2buffers` + // - Add the non-argument tensors to `alloc_buffer` of the root block + Array buffers; + for (const te::Tensor& tensor : tensors) { + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); + info->tensor2buffers[tensor] = buffer; + buffers.push_back(buffer); + + if (!info->IsArg(tensor)) { + info->root_alloc.push_back(info->tensor2buffers[tensor]); + } } - // Step 4. Calculate indices for BufferStore + // Step 3. Calculate indices for BufferStore Array indices; indices.reserve(compute_op->axis.size()); for (const IterVar& iter_var : compute_op->axis) { @@ -123,26 +130,75 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te:: indices.push_back(it->second); } - // Step 5. Create block body. + // Step 4. Create block body. + String block_name{nullptr}; Optional init = NullOpt; Stmt body; if (const auto* reduce = expr_body.as()) { // Case 1. Reduce compute - ICHECK_EQ(reduce->source.size(), 1); - const PrimExpr& lhs = BufferLoad(buffer, indices); - const PrimExpr& rhs = Substitute(info->transformer(reduce->source[0]), var_map); - ICHECK(lhs->dtype == rhs->dtype); - const PrimExpr& reduce_body = reduce->combiner.get()->operator()({lhs}, {rhs})[0]; - const PrimExpr& init_body = reduce->combiner->identity_element[0]; - body = BufferStore(buffer, analyzer->Simplify(reduce_body), indices); - init = BufferStore(buffer, analyzer->Simplify(init_body), indices); + block_name = compute_op->name; + int n_buffers = buffers.size(); + + Array lhs; + Array rhs; + lhs.reserve(n_buffers); + rhs.reserve(n_buffers); + + // Make the LHS operands and RHS operands: + // - A LHS operand is the buffer storing the reduction result, with corresponding indices. + // - A RHS operand is the value to be reduced. + for (int i = 0; i < n_buffers; ++i) { + const PrimExpr& left = BufferLoad(buffers[i], indices); + const PrimExpr& right = + analyzer->Simplify(Substitute(info->transformer(reduce->source[i]), var_map)); + lhs.push_back(left); + rhs.push_back(right); + ICHECK_EQ(left->dtype, right->dtype); + } + + Array temp_vars; + Array body_stmts; + Array init_stmts; + temp_vars.reserve(n_buffers); + body_stmts.reserve(n_buffers); + init_stmts.reserve(n_buffers); + + // - When there is only one buffer, we directly create a BufferStore which stores "combiner(lhs, + // rhs)" into the target buffer position. + // - In case there are multiple buffers, to avoid incorrect results, we create some intermediate + // variables and use LetStmts to bind the variables with "combiner(lhs, rhs)". After that, we + // then store the value of the variables into the target buffer positions. + for (int i = 0; i < n_buffers; ++i) { + const Buffer& buffer = buffers[i]; + init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices)); + PrimExpr value{nullptr}; + if (n_buffers > 1) { + temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype()))); + value = temp_vars.back(); + } else { + value = reduce->combiner.get()->operator()(lhs, rhs)[i]; + } + body_stmts.push_back(BufferStore(buffer, value, indices)); + } + + init = SeqStmt::Flatten(init_stmts); + body = SeqStmt::Flatten(body_stmts); + if (n_buffers > 1) { + // When there are multiple buffers, we wrap the body with LetStmts. + for (int i = n_buffers - 1; i >= 0; --i) { + PrimExpr value = reduce->combiner.get()->operator()(lhs, rhs)[i]; + body = LetStmt(temp_vars[i], std::move(value), std::move(body)); + } + } } else { // Case 2. Data parallel compute + ICHECK_EQ(tensors.size(), 1); + block_name = info->GetUniqueName(tensors[0]->GetNameHint()); const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map); - body = BufferStore(buffer, analyzer->Simplify(compute_body), indices); + body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices); } - // Step 6. Add script_parsing_detect_access attr for auto complete the whole IR. + // Step 5. Add script_parsing_detect_access attr for auto complete the whole IR. Map annotations; auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef { if (const auto* tensor_value = value.as()) { @@ -166,14 +222,14 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te:: // Set script_parsing_detect_access annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3)); - // Step 7. Create Block and BlockRealize. + // Step 6. Create Block and BlockRealize. return BlockRealize(/*iter_values=*/std::move(bindings), /*predicate=*/Bool(true), /*block=*/ Block(/*iter_vars=*/std::move(iter_vars), /*reads=*/{}, /*writes=*/{}, - /*name_hint=*/info->GetUniqueName(tensor->GetNameHint()), + /*name_hint=*/block_name, /*body=*/std::move(body), /*init=*/std::move(init), /*alloc_buffers=*/{}, @@ -192,12 +248,38 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in } // Step 2. Generate block bodies. Array seq_stmt; - for (int i = 0; i < compute_op->num_outputs(); ++i) { - const te::Tensor& tensor = compute_op.output(i); - PrimExpr expr_body = compute_op->body[i]; - seq_stmt.push_back(GenerateBlockFromTensor(compute_op, tensor, bindings, std::move(expr_body), - info, analyzer)); + if (compute_op->body[0]->IsInstance()) { + auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool { + return a->combiner.same_as(b->combiner) && // + a->source.same_as(b->source) && // + a->axis.same_as(b->axis) && // + a->condition.same_as(b->condition) && // + ((a->init.empty() && b->init.empty()) || a->init.same_as(b->init)); + }; + + PrimExpr expr_body = compute_op->body[0]; + Array tensors = {compute_op.output(0)}; + const tir::ReduceNode* reduce = expr_body.as(); + // specially handle reduction inline for multiplre reductions. + for (size_t k = 1; k < compute_op->body.size(); ++k) { + const tir::ReduceNode* reduce_ = compute_op->body[k].as(); + ICHECK(reduce_); + ICHECK(f_reducer_equal(reduce_, reduce)) + << "The Reduce inputs of ComputeOp should have the same attribute except value_index"; + tensors.push_back(compute_op.output(k)); + } + + seq_stmt.push_back(GenerateBlockFromTensors(compute_op, tensors, bindings, std::move(expr_body), + info, analyzer)); + } else { + for (int i = 0; i < compute_op->num_outputs(); ++i) { + const te::Tensor& tensor = compute_op.output(i); + PrimExpr expr_body = compute_op->body[i]; + seq_stmt.push_back(GenerateBlockFromTensors(compute_op, {tensor}, bindings, + std::move(expr_body), info, analyzer)); + } } + Stmt body = SeqStmt::Flatten(seq_stmt); // Step 3. Generate loop nesting. diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 68ea2ab461f5..48082c44a4ab 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -359,6 +359,108 @@ def test_tensor_attr(): tvm.ir.assert_structural_equal(func, rt_func) +def te_argmax_idx_val(): + def f_combine(x, y): + lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1]) + return lhs, rhs + + def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): + return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1) + + argmax = te.comm_reducer(f_combine, f_identity, name="argmax") + + m = te.var("m") + n = te.var("n") + idx = te.placeholder((m, n), name="idx", dtype="int32") + val = te.placeholder((m, n), name="val", dtype="float32") + k = te.reduce_axis((0, n), "k") + max_idx, max_val = te.compute( + (m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="argmax" + ) + return [idx, val, max_idx, max_val] + + +@T.prim_func +def tir_argmax_idx_val( + var_idx: T.handle, var_val: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + m = T.var("int32") + n = T.var("int32") + idx = T.match_buffer(var_idx, [m, n], dtype="int32") + val = T.match_buffer(var_val, [m, n], dtype="float32") + argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="int32") + argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("argmax"): + i, k = T.axis.remap("SR", [i0, i1]) + T.reads(argmax_v1[i], val[i, k], argmax_v0[i], idx[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = T.int32(-1) + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + +def te_argmax_val_idx(): + def f_combine(x, y): + lhs = tvm.tir.Select((x[0] >= y[0]), x[0], y[0]) + rhs = tvm.tir.Select((x[0] >= y[0]), x[1], y[1]) + return lhs, rhs + + def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): + return tvm.te.min_value(dtype0), tvm.tir.const(-1, dtype1) + + argmax = te.comm_reducer(f_combine, f_identity, name="argmax") + + m = te.var("m") + n = te.var("n") + val = te.placeholder((m, n), name="val", dtype="float32") + idx = te.placeholder((m, n), name="idx", dtype="int32") + k = te.reduce_axis((0, n), "k") + max_val, max_idx = te.compute( + (m,), lambda i: argmax((val[i, k], idx[i, k]), axis=k), name="argmax" + ) + return [val, idx, max_val, max_idx] + + +@T.prim_func +def tir_argmax_val_idx( + var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + m = T.var("int32") + n = T.var("int32") + val = T.match_buffer(var_val, [m, n], dtype="float32") + idx = T.match_buffer(var_idx, [m, n], dtype="int32") + argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32") + argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32") + for i0, i1 in T.grid(m, n): + with T.block("argmax"): + i, k = T.axis.remap("SR", [i0, i1]) + T.reads(argmax_v0[i], val[i, k], argmax_v1[i], idx[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = T.min_value("float32") + argmax_v1[i] = T.int32(-1) + v_argmax_v0: T.float32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v0[i], val[i, k]) + v_argmax_v1: T.int32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v1[i], idx[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + +def test_argmax_idx_val(): + _check_workload(te_argmax_idx_val, tir_argmax_idx_val) + + +def test_argmax_val_idx(): + _check_workload(te_argmax_val_idx, tir_argmax_val_idx) + + if __name__ == "__main__": test_unique_name() test_matmul() @@ -371,3 +473,5 @@ def test_tensor_attr(): test_constant() test_select_simplify() test_tensor_attr() + test_argmax_idx_val() + test_argmax_val_idx()