diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 366f4f1deed1..606bf502c195 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -477,6 +477,40 @@ String ShapeString(NDArray shape, DLDataType dtype); */ String ShapeString(const std::vector& shape, DLDataType dtype); +/*! \brief Collect performance information of a function execution. Usually + * used with a compiled PrimFunc (via tvm.build). + * + * This information can include performance counters like cache hits and FLOPs + * that are useful in debugging performance issues of individual PrimFuncs. + * Different metrics can be collected depending on which MetricCollector is + * used. + * + * Example usage: + * \code{.cpp} + * // Use PAPI to measure the number of floating point operations. + * PackedFunc profiler = ProfileModule( + * mod, "main", kDLCPU, 0, {CreatePAPIMetricCollector({{kDLCPU, 0}, {"PAPI_FP_OPS"}})}); + * Report r = profiler(arg1, arg2, arg); + * std::cout << r << std::endl; + * \endcode + * + * \param mod Module to profile. Usually a PrimFunc that has been compiled to machine code. + * \param func_name Name of function to run in the module. + * \param device_type Device type to run on. Profiling will include performance + * metrics specific to this device type. + * \param device_id Id of device to run on. + * \param warmup_iters Number of iterations of the function to run before collecting + * performance information. Recommend to set this larger + * than 0 so that cache effects are consistent. + * \param collectors List of different + * ways to collect metrics. See MetricCollector. + * \returns A PackedFunc which takes the same arguments as the `mod[func_name]` + * and returns performance metrics as a `Map` where + * values can be `CountNode`, `DurationNode`, `PercentNode`. + */ +PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, + int warmup_iters, Array collectors); + } // namespace profiling } // namespace runtime } // namespace tvm diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index 7d40a81e498a..86145ce6242f 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -163,6 +163,56 @@ def __init__(self, dev: Device): self.__init_handle_by_constructor__(_ffi_api.DeviceWrapper, dev) +def profile_function(mod, dev, collectors, func_name="main", warmup_iters=10): + """Collect performance information of a function execution. Usually used with + a compiled PrimFunc. + + This information can include performance counters like cache hits and FLOPs + that are useful in debugging performance issues of individual PrimFuncs. + Different metrics can be collected depending on which MetricCollector is + used. + + Example + ------- + + .. code-block: python + f = tvm.build(my_func, target="llvm", name="my_func") + prof = tvm.runtime.profiling.profile_function( + f, + tvm.cpu(), + [tvm.runtime.profiling.PAPIMetricCollector({tvm.cpu(): ["PAPI_FP_OPS"]}), + ) + counters = prof(*args) + print(counters) + + Parameters + ---------- + mod: Module + Module containing the function to profile. + dev: Device + Device to run the function on. + + collectors: List[MetricCollector] + :py:class:`MetricCollector`s which will collect performance information. + func_name: str + Name of the function in `mod` to profile. Defaults to "main". + warmup_iters: int + Number of iterations to run the function before collecting performance + information. Recommended to set this larger than 0 for consistent cache + effects. Defaults to 10. + + Returns + ------- + prof: PackedFunc[args, Dict[str, ObjectRef]] + PackedFunc which takes the same arguments as the `mod[func_name]` and + returns performance metrics as a `Dict[str, ObjectRef]` where values + can be `CountNode`, `DurationNode`, `PercentNode`. + """ + return _ffi_api.ProfileFunction( + mod, func_name, dev.device_type, dev.device_id, warmup_iters, collectors + ) + + # We only enable this class when TVM is build with PAPI support if _ffi.get_global_func("runtime.profiling.PAPIMetricCollector", allow_missing=True) is not None: diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 90d4ac64238f..000f6eac27ae 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -677,6 +677,63 @@ TVM_REGISTER_GLOBAL("runtime.profiling.FromJSON").set_body_typed(Report::FromJSO TVM_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device dev) { return DeviceWrapper(dev); }); + +PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, + int warmup_iters, Array collectors) { + // Module::GetFunction is not const, so this lambda has to be mutable + return PackedFunc([=](TVMArgs args, TVMRetValue* ret) mutable { + PackedFunc f = mod.GetFunction(func_name); + Device dev{static_cast(device_type), device_id}; + + // warmup + for (int i = 0; i < warmup_iters; i++) { + f.CallPacked(args, ret); + } + + for (auto& collector : collectors) { + collector->Init({DeviceWrapper(dev)}); + } + std::vector> results; + results.reserve(collectors.size()); + std::vector collector_data; + collector_data.reserve(collectors.size()); + for (auto& collector : collectors) { + collector_data.push_back(collector->Start(dev)); + } + + // TODO(tkonolige): repeated calls if the runtime is small? + f.CallPacked(args, ret); + + for (size_t i = 0; i < collectors.size(); i++) { + results.push_back(collectors[i]->Stop(collector_data[i])); + } + Map combined_results; + for (auto m : results) { + for (auto p : m) { + // assume that there is no shared metric name between collectors + combined_results.Set(p.first, p.second); + } + } + *ret = combined_results; + }); +} + +TVM_REGISTER_GLOBAL("runtime.profiling.ProfileFunction") + .set_body_typed)>([](Module mod, String func_name, + int device_type, int device_id, + int warmup_iters, + Array collectors) { + if (mod->type_key() == std::string("rpc")) { + LOG(FATAL) + << "Profiling a module over RPC is not yet supported"; // because we can't send + // MetricCollectors over rpc. + throw; + } else { + return ProfileFunction(mod, func_name, device_type, device_id, warmup_iters, collectors); + } + }); + } // namespace profiling } // namespace runtime } // namespace tvm diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index 4e777435429b..7fa40ea29663 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -29,6 +29,7 @@ from tvm import rpc from tvm.contrib import utils from tvm.runtime.profiling import Report +from tvm.script import tir as T def read_csv(report): @@ -195,6 +196,52 @@ def test_report_serialization(): ) +@T.prim_func +def axpy_cpu(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [10], "float64") + B = T.match_buffer(b, [10], "float64") + C = T.match_buffer(c, [10], "float64") + for i in range(10): + C[i] = A[i] + B[i] + + +@T.prim_func +def axpy_gpu(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [10], "float64") + B = T.match_buffer(b, [10], "float64") + C = T.match_buffer(c, [10], "float64") + for i in T.thread_binding(0, 10, "threadIdx.x"): + C[i] = A[i] + B[i] + + +@tvm.testing.parametrize_targets("cuda", "llvm") +@pytest.mark.skipif( + tvm.get_global_func("runtime.profiling.PAPIMetricCollector", allow_missing=True) is None, + reason="PAPI profiling not enabled", +) +def test_profile_function(target, dev): + target = tvm.target.Target(target) + if str(target.kind) == "llvm": + metric = "PAPI_FP_OPS" + func = axpy_cpu + elif str(target.kind) == "cuda": + metric = ( + "cuda:::gpu__compute_memory_access_throughput.max.pct_of_peak_sustained_region:device=0" + ) + func = axpy_gpu + else: + pytest.skip(f"Target {target.kind} not supported by this test") + f = tvm.build(func, target=target) + a = tvm.nd.array(np.ones(10), device=dev) + b = tvm.nd.array(np.ones(10), device=dev) + c = tvm.nd.array(np.zeros(10), device=dev) + report = tvm.runtime.profiling.profile_function( + f, dev, [tvm.runtime.profiling.PAPIMetricCollector({dev: [metric]})] + )(a, b, c) + assert metric in report.keys() + assert report[metric].value > 0 + + if __name__ == "__main__": import sys import pytest