diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index 55c1ccd6163d..ba3f4872c9f8 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -139,7 +139,11 @@ def instantiate_attention_template(attrs): } CHECK(Attention::check_supported(p)); - kernel_fn<<>>(p); + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + cudaStream_t stream = static_cast((*func)().operator void*()); + + kernel_fn<<>>(p); if (accumulator_buf_allocated) { cudaFree(p.output_accum_ptr); diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py index 8f85fc5382b3..77f4449db232 100644 --- a/python/tvm/contrib/cutlass/conv2d_operation.py +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -423,7 +423,12 @@ def instantiate_conv2d_template(attrs): status = conv2d_op.initialize(arguments, workspace.get()); CHECK(status == cutlass::Status::kSuccess); ${split_k_update} - status = conv2d_op(); + + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + cudaStream_t stream = static_cast((*func)().operator void*()); + + status = conv2d_op(stream); CHECK(status == cutlass::Status::kSuccess); ${split_k_reduction} """ diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py index 86f6e977197d..3fa6e9be8d6e 100644 --- a/python/tvm/contrib/cutlass/gemm_operation.py +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -344,7 +344,12 @@ def instantiate_gemm_template(attrs): CHECK(status == cutlass::Status::kSuccess); status = gemm_op.initialize(arguments, workspace.get()); CHECK(status == cutlass::Status::kSuccess); - status = gemm_op(); + + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + cudaStream_t stream = static_cast((*func)().operator void*()); + + status = gemm_op(stream); CHECK(status == cutlass::Status::kSuccess); """ op_type = attrs["op_type"] @@ -416,38 +421,34 @@ def emit_fp16A_int4B_matmul(attrs): int m = ${A_arg}->shape[${batch_offset}]; int n = ${B_arg}->shape[1] * 2; int k = ${B_arg}->shape[0]; + + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + cudaStream_t stream = static_cast((*func)().operator void*()); """, attrs, ) template = """ ${template_common} - gemm_fp16_int4(static_cast(${A_arg}->data), - static_cast(${B_arg}->data), - static_cast(${scales_arg}->data), - static_cast(out0->data), - m, n, k, nullptr, 0, nullptr); -""" - - template_bias = """ - ${template_common} - gemm_fp16_int4_bias(static_cast(${A_arg}->data), - static_cast(${B_arg}->data), - static_cast(${scales_arg}->data), - static_cast(${bias_arg}->data), - static_cast(out0->data), - m, n, k, nullptr, 0, nullptr); + gemm_fp16_int_bias_act(static_cast(${A_arg}->data), + static_cast<${weight_dtype}*>(${B_arg}->data), + static_cast(${scales_arg}->data), + ${bias}, + static_cast(out0->data), + "${activation}", + m, n, k, ${bias_stride}, nullptr, 0, stream); """ template_residual = """ ${template_common} - gemm_fp16_int4_bias_act_residual(static_cast(${A_arg}->data), - static_cast(${B_arg}->data), - static_cast(${scales_arg}->data), - ${bias}, - static_cast(${residual_arg}->data), - static_cast(out0->data), "${activation}", "${binary_op}", "${unary_op}", - m, n, k, nullptr, 0, nullptr); + gemm_fp16_int_bias_act_residual(static_cast(${A_arg}->data), + static_cast<${weight_dtype}*>(${B_arg}->data), + static_cast(${scales_arg}->data), + ${bias}, + static_cast(${residual_arg}->data), + static_cast(out0->data), "${activation}", "${binary_op}", "${unary_op}", + m, n, k, nullptr, 0, stream); """ if "residual_arg" in attrs and "bias_arg" in attrs: diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 2988f9a8a272..e2756b9f5610 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -483,7 +483,7 @@ def instantiate_template(func_name, annotations, func_args): if k in annotations: attrs[k] = annotations[k] - headers = [] + headers = ["tvm/runtime/registry.h"] if "relu" in func_name: headers.append("cutlass/epilogue/thread/linear_combination_bias_relu.h") diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py b/python/tvm/contrib/cutlass/layer_norm_operation.py index 589f559e935d..ad2e730d276c 100644 --- a/python/tvm/contrib/cutlass/layer_norm_operation.py +++ b/python/tvm/contrib/cutlass/layer_norm_operation.py @@ -39,6 +39,10 @@ def instantiate_layer_norm_template(attrs): cutlass::TensorRef _beta((data_type*)${beta}->data, layout_channels); cutlass::TensorRef _output((data_type*)out0->data, layout_2D); - cutlass::layernorm(size, _output, _input, _gamma, _beta, NULL); + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + cudaStream_t stream = static_cast((*func)().operator void*()); + + cutlass::layernorm(size, _output, _input, _gamma, _beta, stream); """ return substitute_template(template, attrs) diff --git a/python/tvm/contrib/cutlass/rms_norm_operation.py b/python/tvm/contrib/cutlass/rms_norm_operation.py index e24d6bc39ac0..ef0d8ef61f45 100644 --- a/python/tvm/contrib/cutlass/rms_norm_operation.py +++ b/python/tvm/contrib/cutlass/rms_norm_operation.py @@ -38,6 +38,10 @@ def instantiate_rms_norm_template(attrs): cutlass::TensorRef _weight((data_type*)${weight}->data, layout_channels); cutlass::TensorRef _output((data_type*)out0->data, layout_2D); - cutlass::rmsnorm(size, _output, _input, _weight, nullptr); + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + cudaStream_t stream = static_cast((*func)().operator void*()); + + cutlass::rmsnorm(size, _output, _input, _weight, stream); """ return substitute_template(template, attrs) diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index af1064063574..088fa38758af 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -150,8 +150,16 @@ class CUDAGraphRewritePlanner : public ExprVisitor { explicit CUDAGraphRewritePlanner(const IRModule& mod) : mod_(mod) {} std::vector Plan() { for (const auto& pair : mod_->functions) { - const auto& func = pair.second; - if (func->IsInstance()) { + if (pair.second->IsInstance()) { + // If a function has the num_input attribute, the last func->params.size() - num_inputs + // inputs are assumed to be fixed and thus they can be captured into a cuda graph. + static const char* attr_num_input = "num_input"; + const auto& func = Downcast(pair.second); + if (auto num_input = func->attrs.GetAttr(attr_num_input)) { + for (size_t i = num_input.value().IntValue(); i < func->params.size(); ++i) { + static_vars_.insert(func->params[i].get()); + } + } VisitExpr(func); } } @@ -349,7 +357,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { if (vars_collector != nullptr) { vars_collector->push_back(var); } - return static_bindings_.count(var); + return static_vars_.count(var); } if (const auto* shape = expr.as()) { @@ -402,7 +410,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { current_.capture_builder->AddBinding(binding); binding_to_region_[binding->var.get()] = current_.capture_builder; } - static_bindings_.emplace(binding->var.get(), GetRef(binding)); + static_vars_.emplace(binding->var.get()); } /*! \brief The states of the current scope (the BindingBlock) which is a pair of FuncBuilder. @@ -419,8 +427,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor { IRModule mod_; // States of the current scope Scope current_; - // All the static bindings - std::unordered_map static_bindings_; + // Variables whose buffer address is fixed + std::unordered_set static_vars_; // Binding to the FuncBuilder if the binding is lifted. This is used to update the inputs/outputs // of the lifted function when its binding is used outside. std::unordered_map binding_to_region_; diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 71788e52999a..b8854f88cbe1 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -300,5 +300,9 @@ TVM_DLL String GetCudaFreeMemory() { TVM_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory); +TVM_REGISTER_GLOBAL("runtime.get_cuda_stream").set_body_typed([]() { + return static_cast(CUDAThreadEntry::ThreadLocal()->stream); +}); + } // namespace runtime } // namespace tvm diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc index 9d2025d6477f..f6eef9ca259d 100644 --- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc @@ -31,45 +31,22 @@ namespace tvm { namespace runtime { namespace relax_vm { -/*! \brief Represents a CUDA graph. */ -class CUDAGraphNode : public Object { - public: - cudaGraph_t handle_ = nullptr; - - ~CUDAGraphNode() { - if (handle_ != nullptr) { - cudaGraphDestroy(handle_); - } - } - - TVM_DECLARE_FINAL_OBJECT_INFO(CUDAGraphNode, Object); -}; - -/*! - * \brief Managed reference to CUDAGraphNode - * \sa CUDAGraphNode - */ -class CUDAGraph : public ObjectRef { - public: - explicit CUDAGraph(cudaGraph_t handle) { - auto n = make_object(); - n->handle_ = handle; - data_ = std::move(n); - } - TVM_DEFINE_OBJECT_REF_METHODS(CUDAGraph, ObjectRef, CUDAGraphNode); -}; - /*! \brief The cache states of a CUDA graph. */ class CUDAGraphCache : public Object { public: struct CaptureResult { + ~CaptureResult() { + if (exec) { + CUDA_CALL(cudaGraphExecDestroy(exec)); + } + } /*! * \brief Tuple of intemediate tensors in the capture func that will be used outside the * capture func */ ObjectRef states; - /*! \brief The cuda graph instance */ - CUDAGraph graph; + /*! \brief The instantiated cuda graph */ + cudaGraphExec_t exec = nullptr; }; static CUDAGraphCache* Get() { return dmlc::ThreadLocalStore::Get(); } @@ -88,11 +65,8 @@ class CUDAGraphCache : public Object { int64_t entry_index) { if (auto it = capture_cache_.find(entry_index); it != capture_cache_.end()) { // Launch CUDA graph - const auto& [states, cuda_graph] = it->second; - cudaGraphExec_t cuda_graph_exec; - CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, cuda_graph->handle_, NULL, NULL, 0)); - CUDA_CALL(cudaGraphLaunch(cuda_graph_exec, CUDAThreadEntry::ThreadLocal()->stream)); - CUDA_CALL(cudaGraphExecDestroy(cuda_graph_exec)); + const auto& [states, exec] = it->second; + CUDA_CALL(cudaGraphLaunch(exec, CUDAThreadEntry::ThreadLocal()->stream)); return states; } @@ -129,9 +103,10 @@ class CUDAGraphCache : public Object { CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph)); std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream); - entry.graph = CUDAGraph(graph); capture_cache_[entry_index] = entry; + CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_index].exec, graph, NULL, NULL, 0)); CUDA_CALL(cudaStreamDestroy(capture_stream)); + CUDA_CALL(cudaGraphDestroy(graph)); return entry.states; } diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index e1ce46ecb00e..1528141e4ab2 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -85,15 +85,23 @@ def main( pytestmark = [cutlass_enabled] -def build_and_run(mod, inputs_np, target, legalize=True): +def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False): if legalize: mod = relax.transform.LegalizeOps()(mod) # For cpu reference, nop for cutlass. + with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}): + ex = relax.build(mod, target) + dev = tvm.device(target, 0) - ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, dev) f = vm["main"] inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + + # For cuda graph, run the compiled function twice to make sure that we can launch the cached + # graph on the second run. + if cuda_graph: + f(*inputs) + return f(*inputs).numpy() @@ -1554,5 +1562,63 @@ def main( tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) +def test_conv2d_cuda_graph(): + @tvm.script.ir_module + class Conv2d: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), "float16"), + weight1: R.Tensor((16, 3, 3, 16), "float16"), + weight2: R.Tensor((16, 3, 3, 16), "float16"), + weight3: R.Tensor((16, 3, 3, 16), "float16"), + gamma: R.Tensor((16,), "float16"), + beta: R.Tensor((16,), "float16"), + ): + R.func_attr({"num_input": 1}) + with R.dataflow(): + conv1 = R.nn.relu( + R.nn.conv2d( + data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + ) + conv2 = R.nn.relu( + R.nn.conv2d( + conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + ) + ln = R.nn.layer_norm(conv2, gamma, beta, axes=[-1]) + conv3 = R.nn.relu( + R.nn.conv2d( + ln, weight3, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + ) + R.output(conv3) + + return conv3 + + low, high = -1, 1 + data_shape = (16, 32, 32, 16) + weight_shape = (16, 3, 3, 16) + dtype = "float16" + data = np.random.randint(low, high, size=data_shape).astype(dtype) + weight1 = np.random.randint(low, high, size=weight_shape).astype(dtype) + weight2 = np.random.randint(low, high, size=weight_shape).astype(dtype) + weight3 = np.random.randint(low, high, size=weight_shape).astype(dtype) + gamma = np.random.randint(low, high, size=(weight_shape[0],)).astype(dtype) + beta = np.random.randint(low, high, size=(weight_shape[0],)).astype(dtype) + inputs = [data, weight1, weight2, weight3, gamma, beta] + + mod = partition_for_cutlass(Conv2d) + mod = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})(mod) + mod = relax.pipeline.get_pipeline()(mod) # pylint: disable=no-value-for-parameter + + with tvm.target.Target("cuda"): + mod = tvm.tir.transform.DefaultGPUSchedule()(mod) + + out = build_and_run(mod, inputs, "cuda", cuda_graph=True) + ref = build_and_run(Conv2d, inputs, "llvm", legalize=True) + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py index 52362eae013a..106147ef9af0 100644 --- a/tests/python/relax/test_transform_rewrite_cuda_graph.py +++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py @@ -339,5 +339,327 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32 tvm.ir.assert_structural_equal(after, Expected) +def test_capture_fixed_inputs(): + @tvm.script.ir_module + class Conv2dx3: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), "float16"), + weight1: R.Tensor((16, 3, 3, 16), "float16"), + weight2: R.Tensor((16, 3, 3, 16), "float16"), + weight3: R.Tensor((16, 3, 3, 16), "float16"), + gamma: R.Tensor((16,), "float16"), + beta: R.Tensor((16,), "float16"), + ): + R.func_attr({"num_input": 1}) + with R.dataflow(): + conv1 = R.nn.relu( + R.nn.conv2d( + data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + ) + + ############################################################################### + # The second conv2d and layer norm can be captured into a graph + conv2 = R.nn.relu( + R.nn.conv2d( + conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + ) + ln = R.nn.layer_norm(conv2, gamma, beta, axes=[-1]) + ############################################################################### + + conv3 = R.nn.relu( + R.nn.conv2d( + ln, weight3, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + ) + R.output(conv3) + + return conv3 + + @I.ir_module + class Expected: + @T.prim_func + def fused_conv2d_relu( + data: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), + weight1: T.Buffer((T.int64(16), T.int64(3), T.int64(3), T.int64(16)), "float16"), + var_compute_intermediate: T.Buffer( + (T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16" + ), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + pad_temp = T.alloc_buffer( + (T.int64(16), T.int64(34), T.int64(34), T.int64(16)), "float16" + ) + var_conv2d_nhwc_intermediate = T.alloc_buffer( + (T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16" + ) + for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(34), T.int64(34), T.int64(16)): + with T.block("pad_temp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3]) + T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3]) + pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else( + T.int64(1) <= v_i1 + and v_i1 < T.int64(33) + and T.int64(1) <= v_i2 + and v_i2 < T.int64(33), + data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], + T.float16(0), + ) + for nn, yy, xx, ff, ry, rx, rc in T.grid( + T.int64(16), + T.int64(32), + T.int64(32), + T.int64(16), + T.int64(3), + T.int64(3), + T.int64(16), + ): + with T.block("conv2d_nhwc"): + v_nn, v_yy, v_xx, v_ff, v_ry, v_rx, v_rc = T.axis.remap( + "SSSSRRR", [nn, yy, xx, ff, ry, rx, rc] + ) + T.reads( + pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc], + weight1[v_ff, v_ry, v_rx, v_rc], + ) + T.writes(var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx, v_ff]) + with T.init(): + var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx, v_ff] = T.float16(0) + var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx, v_ff] = ( + var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx, v_ff] + + pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc] + * weight1[v_ff, v_ry, v_rx, v_rc] + ) + for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_conv2d_nhwc_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.max( + var_conv2d_nhwc_intermediate[v_i0, v_i1, v_i2, v_i3], T.float16(0) + ) + + @T.prim_func + def layer_norm( + A: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), + B: T.Buffer((T.int64(16),), "float16"), + C: T.Buffer((T.int64(16),), "float16"), + T_layer_norm: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), + ): + T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)}) + # with T.block("root"): + A_red_temp_v0 = T.alloc_buffer((T.int64(16), T.int64(32), T.int64(32))) + A_red_temp_v1 = T.alloc_buffer((T.int64(16), T.int64(32), T.int64(32))) + for ax0, ax1, ax2, k3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)): + with T.block("A_red_temp"): + v_ax0, v_ax1, v_ax2, v_k3 = T.axis.remap("SSSR", [ax0, ax1, ax2, k3]) + T.reads(A[v_ax0, v_ax1, v_ax2, v_k3]) + T.writes(A_red_temp_v0[v_ax0, v_ax1, v_ax2], A_red_temp_v1[v_ax0, v_ax1, v_ax2]) + with T.init(): + A_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.float32(0) + A_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.float32(0) + v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1, v_ax2] + T.Cast( + "float32", A[v_ax0, v_ax1, v_ax2, v_k3] + ) + v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1, v_ax2] + T.Cast( + "float32", A[v_ax0, v_ax1, v_ax2, v_k3] + ) * T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_k3]) + A_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_A_red_temp_v0 + A_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_A_red_temp_v1 + for ax0, ax1, ax2, ax3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)): + with T.block("T_layer_norm"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads( + A[v_ax0, v_ax1, v_ax2, v_ax3], + A_red_temp_v0[v_ax0, v_ax1, v_ax2], + A_red_temp_v1[v_ax0, v_ax1, v_ax2], + B[v_ax3], + C[v_ax3], + ) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3]) + T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = ( + T.Cast( + "float16", + ( + T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_ax3]) + - A_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625) + ) + * T.rsqrt( + A_red_temp_v1[v_ax0, v_ax1, v_ax2] * T.float32(0.0625) + - A_red_temp_v0[v_ax0, v_ax1, v_ax2] + * T.float32(0.0625) + * (A_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625)) + + T.float32(1.0000000000000001e-05) + ), + ) + * B[v_ax3] + + C[v_ax3] + ) + + @R.function(private=True) + def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): + R.func_attr({"relax.force_pure": True}) + storage: R.Object = R.memory.alloc_storage( + R.shape([524288]), R.prim_value(0), R.str("global"), R.dtype("float16") + ) + storage1: R.Object = R.memory.alloc_storage( + R.shape([524288]), R.prim_value(0), R.str("global"), R.dtype("float16") + ) + gv: R.Tuple(R.Object, R.Object) = storage, storage1 + return gv + + @R.function(private=True) + def cuda_graph_capture( + lv: R.Tensor((16, 32, 32, 16), dtype="float16"), + lv1: R.Tensor((16, 3, 3, 16), dtype="float16"), + alloc1: R.Tensor((16, 32, 32, 16), dtype="float16"), + alloc: R.Tensor((16, 32, 32, 16), dtype="float16"), + params: R.Tuple( + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16,), dtype="float16"), + R.Tensor((16,), dtype="float16"), + ), + storage: R.Object, + ) -> R.Tuple( + R.Tensor((16, 32, 32, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 32, 32, 16), dtype="float16"), + ): + R.func_attr({"relax.force_pure": True}) + cls = Expected + _1: R.Tuple = cls.fused_conv2d_relu(lv, lv1, alloc1) + _: R.Tuple = R.memory.kill_tensor(alloc) + lv1_1: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc1 + lv2: R.Tensor((16,), dtype="float16") = params[3] + lv3: R.Tensor((16,), dtype="float16") = params[4] + alloc2: R.Tensor((16, 32, 32, 16), dtype="float16") = R.memory.alloc_tensor( + storage, R.prim_value(0), R.shape([16, 32, 32, 16]), R.dtype("float16") + ) + _2: R.Tuple = cls.layer_norm(lv1_1, lv2, lv3, alloc2) + _1_1: R.Tuple = R.memory.kill_tensor(alloc1) + ln: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc2 + lv4: R.Tensor((16, 3, 3, 16), dtype="float16") = params[2] + gv: R.Tuple( + R.Tensor((16, 32, 32, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 32, 32, 16), dtype="float16"), + ) = (ln, lv4, alloc2) + return gv + + @R.function + def main_transform_params( + params: R.Tuple( + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16,), dtype="float16"), + R.Tensor((16,), dtype="float16"), + ) + ) -> R.Tuple( + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16,), dtype="float16"), + R.Tensor((16,), dtype="float16"), + ): + R.func_attr({"relax.force_pure": True}) + lv: R.Tensor((16, 3, 3, 16), dtype="float16") = params[0] + lv1: R.Tensor((16, 3, 3, 16), dtype="float16") = params[1] + lv2: R.Tensor((16, 3, 3, 16), dtype="float16") = params[2] + lv3: R.Tensor((16,), dtype="float16") = params[3] + lv4: R.Tensor((16,), dtype="float16") = params[4] + gv: R.Tuple( + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16,), dtype="float16"), + R.Tensor((16,), dtype="float16"), + ) = (lv, lv1, lv2, lv3, lv4) + return gv + + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + params: R.Tuple( + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16,), dtype="float16"), + R.Tensor((16,), dtype="float16"), + ), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + R.func_attr({"num_input": 1, "relax.force_pure": True}) + cls = Expected + lv: R.Tensor((16, 3, 3, 16), dtype="float16") = params[0] + gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx( + "vm.builtin.cuda_graph.get_cached_alloc", + (cls.cuda_graph_alloc, R.prim_value(0)), + sinfo_args=(R.Tuple(R.Object, R.Object),), + ) + storage: R.Object = gv[0] + alloc: R.Tensor((16, 32, 32, 16), dtype="float16") = R.memory.alloc_tensor( + storage, R.prim_value(0), R.shape([16, 32, 32, 16]), R.dtype("float16") + ) + _: R.Tuple = cls.fused_conv2d_relu(data, lv, alloc) + lv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc + lv1: R.Tensor((16, 3, 3, 16), dtype="float16") = params[1] + storage1: R.Object = gv[1] + alloc1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.memory.alloc_tensor( + storage1, R.prim_value(0), R.shape([16, 32, 32, 16]), R.dtype("float16") + ) + gv1: R.Tuple( + R.Tensor((16, 32, 32, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 32, 32, 16), dtype="float16"), + ) = R.call_builtin_with_ctx( + "vm.builtin.cuda_graph.run_or_capture", + ( + cls.cuda_graph_capture, + (lv_1, lv1, alloc1, alloc, params, storage), + R.prim_value(0), + ), + sinfo_args=( + R.Tuple( + R.Tensor((16, 32, 32, 16), dtype="float16"), + R.Tensor((16, 3, 3, 16), dtype="float16"), + R.Tensor((16, 32, 32, 16), dtype="float16"), + ), + ), + ) + alloc2: R.Tensor((16, 32, 32, 16), dtype="float16") = gv1[2] + ln: R.Tensor((16, 32, 32, 16), dtype="float16") = gv1[0] + lv4: R.Tensor((16, 3, 3, 16), dtype="float16") = gv1[1] + alloc3: R.Tensor((16, 32, 32, 16), dtype="float16") = R.builtin.alloc_tensor( + R.shape([16, 32, 32, 16]), R.dtype("float16"), R.prim_value(0) + ) + _3: R.Tuple = cls.fused_conv2d_relu(ln, lv4, alloc3) + _2: R.Tuple = R.memory.kill_tensor(alloc2) + gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc3 + _3_1: R.Tuple = R.memory.kill_storage(storage) + _4: R.Tuple = R.memory.kill_storage(storage1) + return gv_1 + + mod = tvm.transform.Sequential( + [ + relax.pipeline.get_pipeline(), + relax.transform.LiftTransformParams(), + relax.transform.ToNonDataflow(), + relax.transform.RemovePurityChecking(), + relax.transform.CallTIRRewrite(), + relax.transform.StaticPlanBlockMemory(), + ] + )(Conv2dx3) + + mod["main"] = mod["main"].with_attr({"num_input": 1}) + after = relax.transform.RewriteCUDAGraph()(mod) + tvm.ir.assert_structural_equal(after, after) + + if __name__ == "__main__": tvm.testing.main()