2 changes: 1 addition & 1 deletion .github/workflows/test-python-package.yml
@@ -28,7 +28,7 @@ jobs:
      - name: Setup Nox
        uses: fjwillemsen/setup-nox2@v3.0.0
      - name: Setup Poetry
        uses: Gr1N/setup-poetry@v8
        uses: Gr1N/setup-poetry@v9
      - run: poetry self add poetry-plugin-export
      - name: Run tests with Nox
        run: |
8 changes: 8 additions & 0 deletions doc/source/observers.rst
@@ -112,3 +112,11 @@ More information about PMT can be found here: https://git.astron.nl/RD/pmt/



NCUObserver
~~~~~~~~~~~

The NCUObserver can be used to automatically extract performance counters during tuning using Nvidia's NsightCompute profiler.
The NCUObserver relies on an intermediate library, which can be found here: https://github.com/nlesc-recruit/nvmetrics
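
A minimal usage sketch, condensed from the example added in this PR (metric names and availability vary per GPU, so treat these as placeholders):

import numpy
from kernel_tuner import tune_kernel
from kernel_tuner.observers.ncu import NCUObserver

kernel_string = """
__global__ void vector_add(float *c, float *a, float *b, int n) {
    int i = blockIdx.x * block_size_x + threadIdx.x;
    if (i < n) {
        c[i] = a[i] + b[i];
    }
}
"""

size = 1000000
a = numpy.random.randn(size).astype(numpy.float32)
b = numpy.random.randn(size).astype(numpy.float32)
c = numpy.zeros_like(a)
args = [c, a, b, numpy.int32(size)]
tune_params = {"block_size_x": [128, 256, 512]}

# observe how many bytes the kernel reads from and writes to DRAM
ncuobserver = NCUObserver(metrics=["dram__bytes_read.sum", "dram__bytes_write.sum"])

# the measured counters are added to the results of every benchmarked configuration
results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params,
                           observers=[ncuobserver])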
Comment on lines +118 to +119
Collaborator Author:

Suggested change
The NCUObserver can be used to automatically extract performance counters during tuning using Nvidia's NsightCompute profiler.
The NCUObserver relies on an intermediate library, which can be found here: https://github.com/nlesc-recruit/nvmetrics
The NCUObserver can be used to automatically extract performance counters during tuning using Nvidia's CUDA Profiling Tools Interface (CUPTI) library, to offer insights into performance characteristics like memory access patterns and instruction execution statistics. It behaves much like the Nvidia Nsight Compute CLI (ncu) application, hence the name NCUObserver. This observer relies on an intermediate library, which can be found here: https://github.com/nlesc-recruit/nvmetrics

Collaborator Author:

I wonder whether we should change the name NCUObserver into something like MetricsObserver. Eventually, we may want to also support measuring metrics on non-NVIDIA GPUs, and the observer could then pick the right backend based on whatever runner is used.

Now that I mention runners, I never tested this observer with anything other than pycuda. The observer needs CUDA to be at least somewhat initialized. When testing nvmetrics with pycuda everything worked fine, but not when using cupy, for instance. Maybe we still need to test this (and add the tests?), or at least mention it somewhere.

Collaborator:

In terms of naming, I think I would prefer to call a more generalized NCUObserver a "ProfilingObserver", because we already use the term 'metrics' for the user-defined derived metrics. For now, I think we can keep the name NCUObserver until we have another "ProfilingObserver"; when we do, it would make sense to think about a more generic type.

I will test with cupy and cuda-python.

Collaborator:

Works with CuPy but not yet with cuda-python.

This is the error I get with cuda-python:

Using: NVIDIA GeForce RTX 3050 Ti Laptop GPU
Error while benchmarking: vector_add
Error while compiling or benchmarking, see source files:  kernel_tuner/examples/cuda/temp_oq2ba5hb.c
Traceback (most recent call last):
  File " kernel_tuner/examples/cuda/vector_add_observers_ncu.py", line 57, in <module>
    tune()
  File " kernel_tuner/examples/cuda/vector_add_observers_ncu.py", line 51, in tune
    results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, observers=[ncuobserver], metrics=metrics, iterations=7, lang='nvcuda')
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " kernel_tuner/kernel_tuner/interface.py", line 678, in tune_kernel
    results = strategy.tune(searchspace, runner, tuning_options)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " kernel_tuner/kernel_tuner/strategies/brute_force.py", line 10, in tune
    return runner.run(searchspace.sorted_list(), tuning_options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " kernel_tuner/kernel_tuner/runners/sequential.py", line 87, in run
    self.dev.compile_and_benchmark(self.kernel_source, self.gpu_args, params, self.kernel_options, tuning_options)
  File " kernel_tuner/kernel_tuner/core.py", line 614, in compile_and_benchmark
    raise e
  File " kernel_tuner/kernel_tuner/core.py", line 603, in compile_and_benchmark
    self.benchmark(func, gpu_args, instance, verbose, to.objective, skip_nvml_setting=False)
  File " kernel_tuner/kernel_tuner/core.py", line 472, in benchmark
    raise e
  File " kernel_tuner/kernel_tuner/core.py", line 437, in benchmark
    self.benchmark_prologue(func, gpu_args, instance.threads, instance.grid, result)
  File " kernel_tuner/kernel_tuner/core.py", line 355, in benchmark_prologue
    obs.before_start()
  File " kernel_tuner/kernel_tuner/observers/ncu.py", line 35, in before_start
    nvmetrics.measureMetricsStart(self.metrics, self.device)
RuntimeError: /home/ben/documents/kernel_tuner/nvmetrics/lib/nv_metrics.cpp:147: error: function cuDeviceGet(&cuDevice, deviceNum) failed with error

Collaborator:

Not sure how this can happen. The initialization of the nvcuda backend is very similar to that of pycuda, and by the time we call before_start() it should have been initialized for a while already.

Collaborator:

In the nvcuda backend, I tried to create a new context and destroy it atexit, as pycuda does, but that doesn't change anything.

Collaborator Author:

I think it's fixed in the nvmetrics version that I just pushed (revision c7c42130). It now calls cuInit, which doesn't seem to hurt, even if pycuda has already initialized CUDA before. To get a valid CUDA context, cuDevicePrimaryCtxRetain is used.
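
For reference, the initialization pattern described above looks roughly like this when sketched with the cuda-python driver bindings (an illustration of the approach, not the actual nvmetrics C++ code):

from cuda import cuda

# cuInit is safe to call even when another component (such as pycuda)
# has already initialized the driver
err, = cuda.cuInit(0)

# retaining the primary context yields a valid CUDA context without
# creating (and later having to destroy) a new one
err, device = cuda.cuDeviceGet(0)
err, context = cuda.cuDevicePrimaryCtxRetain(device)

# ... measure metrics here ...

# release the primary context when done
err, = cuda.cuDevicePrimaryCtxRelease(device)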

Collaborator:

Awesome! I just pulled and rebuilt the nvmetrics library and tested it with every CUDA backend. It all works now!

Collaborator Author:

That's great! I plan to make a 1.0 release of nvmetrics to go along with this PR.


.. autoclass:: kernel_tuner.observers.ncu.NCUObserver

57 changes: 57 additions & 0 deletions examples/cuda/vector_add_observers_ncu.py
@@ -0,0 +1,57 @@
#!/usr/bin/env python
"""Example that demonstrates how to use the NCUObserver with a simple vector add kernel."""
import numpy
from kernel_tuner import tune_kernel
from kernel_tuner.observers.ncu import NCUObserver

def tune():

    kernel_string = """
    __global__ void vector_add(float *c, float *a, float *b, int n) {
        int i = blockIdx.x * block_size_x + threadIdx.x;
        if (i<n) {
            c[i] = a[i] + b[i];
        }
    }
    """

    size = 80000000

    a = numpy.random.randn(size).astype(numpy.float32)
    b = numpy.random.randn(size).astype(numpy.float32)
    c = numpy.zeros_like(b)
    n = numpy.int32(size)

    args = [c, a, b, n]

    tune_params = dict()
    tune_params["block_size_x"] = [128 + 64 * i for i in range(15)]

    ncu_metrics = ["dram__bytes.sum",                                      # Counter byte # of bytes accessed in DRAM
                   "dram__bytes_read.sum",                                 # Counter byte # of bytes read from DRAM
                   "dram__bytes_write.sum",                                # Counter byte # of bytes written to DRAM
                   "smsp__sass_thread_inst_executed_op_fadd_pred_on.sum",  # Counter inst # of FADD thread instructions executed where all predicates were true
                   "smsp__sass_thread_inst_executed_op_ffma_pred_on.sum",  # Counter inst # of FFMA thread instructions executed where all predicates were true
                   "smsp__sass_thread_inst_executed_op_fmul_pred_on.sum",  # Counter inst # of FMUL thread instructions executed where all predicates were true
                  ]

    ncuobserver = NCUObserver(metrics=ncu_metrics)

    def total_fp32_flops(p):
        return p["smsp__sass_thread_inst_executed_op_fadd_pred_on.sum"] + 2 * p["smsp__sass_thread_inst_executed_op_ffma_pred_on.sum"] + p["smsp__sass_thread_inst_executed_op_fmul_pred_on.sum"]

    metrics = dict()
    metrics["GFLOP/s"] = lambda p: (total_fp32_flops(p) / 1e9) / (p["time"] / 1e3)
    metrics["Expected GFLOP/s"] = lambda p: (size / 1e9) / (p["time"] / 1e3)
    metrics["GB/s"] = lambda p: (p["dram__bytes.sum"] / 1e9) / (p["time"] / 1e3)
    metrics["Expected GB/s"] = lambda p: (size * 4 * 3 / 1e9) / (p["time"] / 1e3)

    results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, observers=[ncuobserver], metrics=metrics, iterations=7)

    return results


if __name__ == "__main__":
    tune()
43 changes: 28 additions & 15 deletions kernel_tuner/core.py
@@ -23,7 +23,7 @@
from kernel_tuner.backends.opencl import OpenCLFunctions
from kernel_tuner.backends.hip import HipFunctions
from kernel_tuner.observers.nvml import NVMLObserver
from kernel_tuner.observers.observer import ContinuousObserver, OutputObserver
from kernel_tuner.observers.observer import ContinuousObserver, OutputObserver, PrologueObserver

try:
    import torch
@@ -314,11 +314,13 @@ def __init__(
            )
        else:
            raise ValueError("Sorry, support for languages other than CUDA, OpenCL, HIP, C, and Fortran is not implemented yet")
        self.dev = dev

        # look for NVMLObserver in observers, if present, enable special tunable parameters through nvml
        self.use_nvml = False
        self.continuous_observers = []
        self.output_observers = []
        self.prologue_observers = []
        if observers:
            for obs in observers:
                if isinstance(obs, NVMLObserver):
@@ -328,49 +330,61 @@
                    self.continuous_observers.append(obs.continuous_observer)
                if isinstance(obs, OutputObserver):
                    self.output_observers.append(obs)
                if isinstance(obs, PrologueObserver):
                    self.prologue_observers.append(obs)

        # Take list of observers from self.dev because Backends tend to add their own observer
        self.benchmark_observers = [
            obs for obs in self.dev.observers if not isinstance(obs, (ContinuousObserver, PrologueObserver))
        ]

        self.iterations = iterations

        self.lang = lang
        self.dev = dev
        self.units = dev.units
        self.name = dev.name
        self.max_threads = dev.max_threads
        if not quiet:
            print("Using: " + self.dev.name)

    def benchmark_prologue(self, func, gpu_args, threads, grid, result):
        """Benchmark prologue: run one kernel execution per PrologueObserver."""
        for obs in self.prologue_observers:
            self.dev.synchronize()
            obs.before_start()
            self.dev.run_kernel(func, gpu_args, threads, grid)
            self.dev.synchronize()
            obs.after_finish()
            result.update(obs.get_results())

    def benchmark_default(self, func, gpu_args, threads, grid, result):
        """Benchmark one kernel execution at a time."""
        observers = [
            obs for obs in self.dev.observers if not isinstance(obs, ContinuousObserver)
        ]
        """Benchmark 'iterations' kernel executions, one at a time."""

        self.dev.synchronize()
        for _ in range(self.iterations):
            for obs in observers:
            for obs in self.benchmark_observers:
                obs.before_start()
            self.dev.synchronize()
            self.dev.start_event()
            self.dev.run_kernel(func, gpu_args, threads, grid)
            self.dev.stop_event()
            for obs in observers:
            for obs in self.benchmark_observers:
                obs.after_start()
            while not self.dev.kernel_finished():
                for obs in observers:
                for obs in self.benchmark_observers:
                    obs.during()
                time.sleep(1e-6)  # one microsecond
            self.dev.synchronize()
            for obs in observers:
            for obs in self.benchmark_observers:
                obs.after_finish()

        for obs in observers:
        for obs in self.benchmark_observers:
            result.update(obs.get_results())

    def benchmark_continuous(self, func, gpu_args, threads, grid, result, duration):
        """Benchmark continuously for at least 'duration' seconds"""
        iterations = int(np.ceil(duration / (result["time"] / 1000)))
        # print(f"{iterations=} {(result['time']/1000)=}")
        self.dev.synchronize()
        for obs in self.continuous_observers:
            obs.before_start()
@@ -420,9 +434,8 @@ def benchmark(self, func, gpu_args, instance, verbose, objective, skip_nvml_sett

        result = {}
        try:
            self.benchmark_default(
                func, gpu_args, instance.threads, instance.grid, result
            )
            self.benchmark_prologue(func, gpu_args, instance.threads, instance.grid, result)
            self.benchmark_default(func, gpu_args, instance.threads, instance.grid, result)

            if self.continuous_observers:
                duration = 1
2 changes: 1 addition & 1 deletion kernel_tuner/observers/__init__.py
@@ -1 +1 @@
from .observer import BenchmarkObserver, IterationObserver, ContinuousObserver, OutputObserver
from .observer import BenchmarkObserver, IterationObserver, ContinuousObserver, OutputObserver, PrologueObserver
41 changes: 41 additions & 0 deletions kernel_tuner/observers/ncu.py
@@ -0,0 +1,41 @@
from kernel_tuner.observers import PrologueObserver

try:
    import nvmetrics
except ImportError:
    nvmetrics = None

class NCUObserver(PrologueObserver):
"""``NCUObserver`` measures performance counters.

The exact performance counters supported differ per GPU, some examples:

* "dram__bytes.sum", # Counter byte # of bytes accessed in DRAM
* "dram__bytes_read.sum", # Counter byte # of bytes read from DRAM
* "dram__bytes_write.sum", # Counter byte # of bytes written to DRAM
* "smsp__sass_thread_inst_executed_op_fadd_pred_on.sum", # Counter inst # of FADD thread instructions executed where all predicates were true
* "smsp__sass_thread_inst_executed_op_ffma_pred_on.sum", # Counter inst # of FFMA thread instructions executed where all predicates were true
* "smsp__sass_thread_inst_executed_op_fmul_pred_on.sum", # Counter inst # of FMUL thread instructions executed where all predicates were true

:param metrics: The metrics to observe. This should be a list of strings.
You can use ``ncu --query-metrics`` to get a list of valid metrics.
Member:

Great! Here is an idea: would it be much work to add a way to get the list of supported metrics from nvmetrics? Something like NCUObserver.get_available_metrics() could be useful to check which metrics are supported by the current GPU. Not a critical issue, just a thought.

Collaborator Author:

A simple way to implement this is to let NCUObserver call the ncu command-line profiler with --query-metrics. It would arguably be better if nvmetrics could do this, as it links to CUPTI etc. anyway. I will try to figure out whether this is possible.

I have also thought about making this observer more user-friendly by providing presets of metrics. It is very tricky, though: many metrics have mandatory suffixes like .sum or .average, we don't know what kind of metrics users are interested in, and the available metrics may differ per architecture. I think this observer should mainly be used by expert users who have profiled their code before (using ncu or the GUI) and found some metrics that they would like to measure across a larger parameter space.

Collaborator:

Aren't there at least some metrics that are available on all GPUs that could be listed? I mean, to get people started with the NCUObserver; if they want more, they can still query ncu --query-metrics.

Collaborator Author:

I have added the functionality to query metrics to nvmetrics, including a Python interface. You can now do:

nvmetrics.queryMetrics(nvmetrics.NVPW_METRIC_TYPE_COUNTER)
nvmetrics.queryMetrics(nvmetrics.NVPW_METRIC_TYPE_RATIO)
nvmetrics.queryMetrics(nvmetrics.NVPW_METRIC_TYPE_THROUGHPUT)

These different metric types should be passed with a suffix; the suffixes are called NVPW_RollupOp and NVPW_Submetric, but not all combinations are valid, see this struct:

    typedef struct NVPW_MetricEvalRequest
    {
        /// the metric index as in 'NVPW_MetricsEvaluator_GetMetricNames'
        size_t metricIndex;
        /// one of 'NVPW_MetricType'
        uint8_t metricType;
        /// one of 'NVPW_RollupOp', required for Counter and Throughput, doesn't apply to Ratio
        uint8_t rollupOp;
        /// one of 'NVPW_Submetric', required for Ratio and Throughput, optional for Counter
        uint16_t submetric;
    } NVPW_MetricEvalRequest;

Collaborator Author:

@stijnh, how would you suggest proceeding here?
We could for instance use the same metrics that the example uses as default.
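
A possible shape for that helper, building on the queryMetrics interface shown above (the function itself is hypothetical and not part of this PR):

import nvmetrics

def get_available_metrics():
    """Hypothetical helper that lists the base metric names nvmetrics can evaluate."""
    metric_types = [
        nvmetrics.NVPW_METRIC_TYPE_COUNTER,
        nvmetrics.NVPW_METRIC_TYPE_RATIO,
        nvmetrics.NVPW_METRIC_TYPE_THROUGHPUT,
    ]
    available = []
    for metric_type in metric_types:
        # assumes queryMetrics returns a list of metric name strings
        available.extend(nvmetrics.queryMetrics(metric_type))
    return available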

    :type metrics: list[str]
    """

    def __init__(self, metrics=None, device=0):
        if not nvmetrics:
            raise ImportError("could not import nvmetrics")

        self.metrics = metrics
        self.device = device
        self.results = dict()

    def before_start(self):
        nvmetrics.measureMetricsStart(self.metrics, self.device)

    def after_finish(self):
        self.results = nvmetrics.measureMetricsStop()

    def get_results(self):
        return dict(zip(self.metrics, self.results))
11 changes: 11 additions & 0 deletions kernel_tuner/observers/observer.py
@@ -57,4 +57,15 @@ def process_output(self, answer, output):
"""
pass

class PrologueObserver(BenchmarkObserver):
    """Observer that measures something in a separate kernel invocation prior to the normal benchmark."""

    @abstractmethod
    def before_start(self):
"""prologue start is called before the kernel starts"""
pass

    @abstractmethod
    def after_finish(self):
"""prologue finish is called after the kernel has finished execution"""
pass
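
For illustration, a minimal concrete PrologueObserver might look as follows. This hypothetical observer is not part of this PR; it just shows how before_start() and after_finish() bracket the extra kernel invocation, and how get_results() feeds a measurement back into the benchmark results:

import time

from kernel_tuner.observers import PrologueObserver

class WallClockObserver(PrologueObserver):
    """Hypothetical observer that wall-clock times the prologue kernel invocation."""

    def before_start(self):
        self.t_start = time.perf_counter()

    def after_finish(self):
        self.t_end = time.perf_counter()

    def get_results(self):
        return {"prologue_wall_time": self.t_end - self.t_start}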