Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/tvm/contrib/cutlass/attention_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ def instantiate_attention_template(attrs):
}

CHECK(Attention::check_supported(p));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());

kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);

if (accumulator_buf_allocated) {
cudaFree(p.output_accum_ptr);
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,12 @@ def instantiate_conv2d_template(attrs):
status = conv2d_op.initialize(arguments, workspace.get());
CHECK(status == cutlass::Status::kSuccess);
${split_k_update}
status = conv2d_op();

auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());

status = conv2d_op(stream);
CHECK(status == cutlass::Status::kSuccess);
${split_k_reduction}
"""
Expand Down
47 changes: 24 additions & 23 deletions python/tvm/contrib/cutlass/gemm_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,12 @@ def instantiate_gemm_template(attrs):
CHECK(status == cutlass::Status::kSuccess);
status = gemm_op.initialize(arguments, workspace.get());
CHECK(status == cutlass::Status::kSuccess);
status = gemm_op();

auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());

status = gemm_op(stream);
CHECK(status == cutlass::Status::kSuccess);
"""
op_type = attrs["op_type"]
Expand Down Expand Up @@ -416,38 +421,34 @@ def emit_fp16A_int4B_matmul(attrs):
int m = ${A_arg}->shape[${batch_offset}];
int n = ${B_arg}->shape[1] * 2;
int k = ${B_arg}->shape[0];

auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
""",
attrs,
)

template = """
${template_common}
gemm_fp16_int4(static_cast<cutlass::half_t*>(${A_arg}->data),
static_cast<cutlass::uint4b_t*>(${B_arg}->data),
static_cast<cutlass::half_t*>(${scales_arg}->data),
static_cast<cutlass::half_t*>(out0->data),
m, n, k, nullptr, 0, nullptr);
"""

template_bias = """
${template_common}
gemm_fp16_int4_bias(static_cast<cutlass::half_t*>(${A_arg}->data),
static_cast<cutlass::uint4b_t*>(${B_arg}->data),
static_cast<cutlass::half_t*>(${scales_arg}->data),
static_cast<cutlass::half_t*>(${bias_arg}->data),
static_cast<cutlass::half_t*>(out0->data),
m, n, k, nullptr, 0, nullptr);
gemm_fp16_int_bias_act(static_cast<cutlass::half_t*>(${A_arg}->data),
static_cast<${weight_dtype}*>(${B_arg}->data),
static_cast<cutlass::half_t*>(${scales_arg}->data),
${bias},
static_cast<cutlass::half_t*>(out0->data),
"${activation}",
m, n, k, ${bias_stride}, nullptr, 0, stream);
"""

template_residual = """
${template_common}
gemm_fp16_int4_bias_act_residual(static_cast<cutlass::half_t*>(${A_arg}->data),
static_cast<cutlass::uint4b_t*>(${B_arg}->data),
static_cast<cutlass::half_t*>(${scales_arg}->data),
${bias},
static_cast<cutlass::half_t*>(${residual_arg}->data),
static_cast<cutlass::half_t*>(out0->data), "${activation}", "${binary_op}", "${unary_op}",
m, n, k, nullptr, 0, nullptr);
gemm_fp16_int_bias_act_residual(static_cast<cutlass::half_t*>(${A_arg}->data),
static_cast<${weight_dtype}*>(${B_arg}->data),
static_cast<cutlass::half_t*>(${scales_arg}->data),
${bias},
static_cast<cutlass::half_t*>(${residual_arg}->data),
static_cast<cutlass::half_t*>(out0->data), "${activation}", "${binary_op}", "${unary_op}",
m, n, k, nullptr, 0, stream);
"""

if "residual_arg" in attrs and "bias_arg" in attrs:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/cutlass/gen_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def instantiate_template(func_name, annotations, func_args):
if k in annotations:
attrs[k] = annotations[k]

headers = []
headers = ["tvm/runtime/registry.h"]

if "relu" in func_name:
headers.append("cutlass/epilogue/thread/linear_combination_bias_relu.h")
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/contrib/cutlass/layer_norm_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def instantiate_layer_norm_template(attrs):
cutlass::TensorRef<data_type, RowMajor> _beta((data_type*)${beta}->data, layout_channels);
cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, layout_2D);

cutlass::layernorm(size, _output, _input, _gamma, _beta, NULL);
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());

cutlass::layernorm(size, _output, _input, _gamma, _beta, stream);
"""
return substitute_template(template, attrs)
6 changes: 5 additions & 1 deletion python/tvm/contrib/cutlass/rms_norm_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def instantiate_rms_norm_template(attrs):
cutlass::TensorRef<data_type, RowMajor> _weight((data_type*)${weight}->data, layout_channels);
cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, layout_2D);

cutlass::rmsnorm(size, _output, _input, _weight, nullptr);
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());

cutlass::rmsnorm(size, _output, _input, _weight, stream);
"""
return substitute_template(template, attrs)
20 changes: 14 additions & 6 deletions src/relax/transform/rewrite_cuda_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,16 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
explicit CUDAGraphRewritePlanner(const IRModule& mod) : mod_(mod) {}
std::vector<LiftedFunctionRewritePlan> Plan() {
for (const auto& pair : mod_->functions) {
const auto& func = pair.second;
if (func->IsInstance<FunctionNode>()) {
if (pair.second->IsInstance<FunctionNode>()) {
// If a function has the num_input attribute, the last func->params.size() - num_inputs
// inputs are assumed to be fixed and thus they can be captured into a cuda graph.
static const char* attr_num_input = "num_input";
const auto& func = Downcast<Function>(pair.second);
if (auto num_input = func->attrs.GetAttr<Integer>(attr_num_input)) {
for (size_t i = num_input.value().IntValue(); i < func->params.size(); ++i) {
static_vars_.insert(func->params[i].get());
}
}
VisitExpr(func);
}
}
Expand Down Expand Up @@ -349,7 +357,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
if (vars_collector != nullptr) {
vars_collector->push_back(var);
}
return static_bindings_.count(var);
return static_vars_.count(var);
}

if (const auto* shape = expr.as<ShapeExprNode>()) {
Expand Down Expand Up @@ -402,7 +410,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
current_.capture_builder->AddBinding(binding);
binding_to_region_[binding->var.get()] = current_.capture_builder;
}
static_bindings_.emplace(binding->var.get(), GetRef<VarBinding>(binding));
static_vars_.emplace(binding->var.get());
}

/*! \brief The states of the current scope (the BindingBlock) which is a pair of FuncBuilder.
Expand All @@ -419,8 +427,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
IRModule mod_;
// States of the current scope
Scope current_;
// All the static bindings
std::unordered_map<const VarNode*, VarBinding> static_bindings_;
// Variables whose buffer address is fixed
std::unordered_set<const VarNode*> static_vars_;
// Binding to the FuncBuilder if the binding is lifted. This is used to update the inputs/outputs
// of the lifted function when its binding is used outside.
std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,5 +300,9 @@ TVM_DLL String GetCudaFreeMemory() {

TVM_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory);

TVM_REGISTER_GLOBAL("runtime.get_cuda_stream").set_body_typed([]() {
return static_cast<void*>(CUDAThreadEntry::ThreadLocal()->stream);
});

} // namespace runtime
} // namespace tvm
47 changes: 11 additions & 36 deletions src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,45 +31,22 @@ namespace tvm {
namespace runtime {
namespace relax_vm {

/*! \brief Represents a CUDA graph. */
class CUDAGraphNode : public Object {
public:
cudaGraph_t handle_ = nullptr;

~CUDAGraphNode() {
if (handle_ != nullptr) {
cudaGraphDestroy(handle_);
}
}

TVM_DECLARE_FINAL_OBJECT_INFO(CUDAGraphNode, Object);
};

/*!
* \brief Managed reference to CUDAGraphNode
* \sa CUDAGraphNode
*/
class CUDAGraph : public ObjectRef {
public:
explicit CUDAGraph(cudaGraph_t handle) {
auto n = make_object<CUDAGraphNode>();
n->handle_ = handle;
data_ = std::move(n);
}
TVM_DEFINE_OBJECT_REF_METHODS(CUDAGraph, ObjectRef, CUDAGraphNode);
};

/*! \brief The cache states of a CUDA graph. */
class CUDAGraphCache : public Object {
public:
struct CaptureResult {
~CaptureResult() {
if (exec) {
CUDA_CALL(cudaGraphExecDestroy(exec));
}
}
/*!
* \brief Tuple of intemediate tensors in the capture func that will be used outside the
* capture func
*/
ObjectRef states;
/*! \brief The cuda graph instance */
CUDAGraph graph;
/*! \brief The instantiated cuda graph */
cudaGraphExec_t exec = nullptr;
};

static CUDAGraphCache* Get() { return dmlc::ThreadLocalStore<CUDAGraphCache>::Get(); }
Expand All @@ -88,11 +65,8 @@ class CUDAGraphCache : public Object {
int64_t entry_index) {
if (auto it = capture_cache_.find(entry_index); it != capture_cache_.end()) {
// Launch CUDA graph
const auto& [states, cuda_graph] = it->second;
cudaGraphExec_t cuda_graph_exec;
CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, cuda_graph->handle_, NULL, NULL, 0));
CUDA_CALL(cudaGraphLaunch(cuda_graph_exec, CUDAThreadEntry::ThreadLocal()->stream));
CUDA_CALL(cudaGraphExecDestroy(cuda_graph_exec));
const auto& [states, exec] = it->second;
CUDA_CALL(cudaGraphLaunch(exec, CUDAThreadEntry::ThreadLocal()->stream));
return states;
}

Expand Down Expand Up @@ -129,9 +103,10 @@ class CUDAGraphCache : public Object {
CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph));
std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);

entry.graph = CUDAGraph(graph);
capture_cache_[entry_index] = entry;
CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_index].exec, graph, NULL, NULL, 0));
CUDA_CALL(cudaStreamDestroy(capture_stream));
CUDA_CALL(cudaGraphDestroy(graph));
return entry.states;
}

Expand Down
70 changes: 68 additions & 2 deletions tests/python/relax/test_codegen_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,23 @@ def main(
pytestmark = [cutlass_enabled]


def build_and_run(mod, inputs_np, target, legalize=True):
def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False):
if legalize:
mod = relax.transform.LegalizeOps()(mod) # For cpu reference, nop for cutlass.

with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}):
ex = relax.build(mod, target)

dev = tvm.device(target, 0)
ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, dev)
f = vm["main"]
inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]

# For cuda graph, run the compiled function twice to make sure that we can launch the cached
# graph on the second run.
if cuda_graph:
f(*inputs)

return f(*inputs).numpy()


Expand Down Expand Up @@ -1554,5 +1562,63 @@ def main(
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


def test_conv2d_cuda_graph():
@tvm.script.ir_module
class Conv2d:
@R.function
def main(
data: R.Tensor((16, 32, 32, 16), "float16"),
weight1: R.Tensor((16, 3, 3, 16), "float16"),
weight2: R.Tensor((16, 3, 3, 16), "float16"),
weight3: R.Tensor((16, 3, 3, 16), "float16"),
gamma: R.Tensor((16,), "float16"),
beta: R.Tensor((16,), "float16"),
):
R.func_attr({"num_input": 1})
with R.dataflow():
conv1 = R.nn.relu(
R.nn.conv2d(
data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
)
)
conv2 = R.nn.relu(
R.nn.conv2d(
conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
)
)
ln = R.nn.layer_norm(conv2, gamma, beta, axes=[-1])
conv3 = R.nn.relu(
R.nn.conv2d(
ln, weight3, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
)
)
R.output(conv3)

return conv3

low, high = -1, 1
data_shape = (16, 32, 32, 16)
weight_shape = (16, 3, 3, 16)
dtype = "float16"
data = np.random.randint(low, high, size=data_shape).astype(dtype)
weight1 = np.random.randint(low, high, size=weight_shape).astype(dtype)
weight2 = np.random.randint(low, high, size=weight_shape).astype(dtype)
weight3 = np.random.randint(low, high, size=weight_shape).astype(dtype)
gamma = np.random.randint(low, high, size=(weight_shape[0],)).astype(dtype)
beta = np.random.randint(low, high, size=(weight_shape[0],)).astype(dtype)
inputs = [data, weight1, weight2, weight3, gamma, beta]

mod = partition_for_cutlass(Conv2d)
mod = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})(mod)
mod = relax.pipeline.get_pipeline()(mod) # pylint: disable=no-value-for-parameter

with tvm.target.Target("cuda"):
mod = tvm.tir.transform.DefaultGPUSchedule()(mod)

out = build_and_run(mod, inputs, "cuda", cuda_graph=True)
ref = build_and_run(Conv2d, inputs, "llvm", legalize=True)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
tvm.testing.main()
Loading