Add NCUObserver #253
Changes from all commits: 4fe1843, 2183d13, 225dae7, 4cc68c5, 3602286, 397306f, 3402808, 35910f5, 59bd891, 7bb58c3, 092caad, 742faa4
New file (+57 lines):

```python
#!/usr/bin/env python
"""This is the minimal example from the README"""
import json

import numpy
from kernel_tuner import tune_kernel
from kernel_tuner.observers.ncu import NCUObserver


def tune():

    kernel_string = """
    __global__ void vector_add(float *c, float *a, float *b, int n) {
        int i = blockIdx.x * block_size_x + threadIdx.x;
        if (i<n) {
            c[i] = a[i] + b[i];
        }
    }
    """

    size = 80000000

    a = numpy.random.randn(size).astype(numpy.float32)
    b = numpy.random.randn(size).astype(numpy.float32)
    c = numpy.zeros_like(b)
    n = numpy.int32(size)

    args = [c, a, b, n]

    tune_params = dict()
    tune_params["block_size_x"] = [128 + 64 * i for i in range(15)]

    ncu_metrics = ["dram__bytes.sum",                                      # of bytes accessed in DRAM
                   "dram__bytes_read.sum",                                 # of bytes read from DRAM
                   "dram__bytes_write.sum",                                # of bytes written to DRAM
                   "smsp__sass_thread_inst_executed_op_fadd_pred_on.sum",  # of FADD thread instructions executed where all predicates were true
                   "smsp__sass_thread_inst_executed_op_ffma_pred_on.sum",  # of FFMA thread instructions executed where all predicates were true
                   "smsp__sass_thread_inst_executed_op_fmul_pred_on.sum",  # of FMUL thread instructions executed where all predicates were true
                   ]

    ncuobserver = NCUObserver(metrics=ncu_metrics)

    def total_fp32_flops(p):
        return (p["smsp__sass_thread_inst_executed_op_fadd_pred_on.sum"]
                + 2 * p["smsp__sass_thread_inst_executed_op_ffma_pred_on.sum"]
                + p["smsp__sass_thread_inst_executed_op_fmul_pred_on.sum"])

    metrics = dict()
    metrics["GFLOP/s"] = lambda p: (total_fp32_flops(p) / 1e9) / (p["time"] / 1e3)
    metrics["Expected GFLOP/s"] = lambda p: (size / 1e9) / (p["time"] / 1e3)
    metrics["GB/s"] = lambda p: (p["dram__bytes.sum"] / 1e9) / (p["time"] / 1e3)
    metrics["Expected GB/s"] = lambda p: (size * 4 * 3 / 1e9) / (p["time"] / 1e3)

    results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params,
                               observers=[ncuobserver], metrics=metrics, iterations=7)

    return results


if __name__ == "__main__":
    tune()
```
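The derived metrics in the example above can be sanity-checked by hand: vector addition performs one FP32 add per element (so the FFMA and FMUL counters should contribute nothing for this kernel), and each element involves three float32 accesses (read `a[i]`, read `b[i]`, write `c[i]`), which is where the `size * 4 * 3` in "Expected GB/s" comes from. A small sketch, using a hypothetical kernel time of 4 ms (the real value comes from the tuner):

```python
size = 80000000  # number of elements, as in the example

# one FADD per element -> `size` FP32 operations in total
expected_flops = size

# per element: read a[i], read b[i], write c[i] -> 3 accesses of 4 bytes each
expected_bytes = size * 4 * 3

# plug a hypothetical kernel time of 4 ms into the same formulas the example uses
time_ms = 4.0
expected_gflops = (expected_flops / 1e9) / (time_ms / 1e3)
expected_gbs = (expected_bytes / 1e9) / (time_ms / 1e3)

print(round(expected_gflops, 3))  # 20.0 GFLOP/s
print(round(expected_gbs, 3))     # 240.0 GB/s
```

Comparing the measured "GB/s" against "Expected GB/s" then shows directly whether the kernel actually moves only the minimal amount of DRAM traffic.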
Changed file (1 line):

```diff
@@ -1 +1 @@
-from .observer import BenchmarkObserver, IterationObserver, ContinuousObserver, OutputObserver
+from .observer import BenchmarkObserver, IterationObserver, ContinuousObserver, OutputObserver, PrologueObserver
```
New file `kernel_tuner/observers/ncu.py` (+41 lines):

```python
from kernel_tuner.observers import PrologueObserver

try:
    import nvmetrics
except ImportError:
    nvmetrics = None


class NCUObserver(PrologueObserver):
    """``NCUObserver`` measures performance counters.

    The exact performance counters supported differ per GPU, some examples:

    * "dram__bytes.sum",                                      # of bytes accessed in DRAM
    * "dram__bytes_read.sum",                                 # of bytes read from DRAM
    * "dram__bytes_write.sum",                                # of bytes written to DRAM
    * "smsp__sass_thread_inst_executed_op_fadd_pred_on.sum",  # of FADD thread instructions executed where all predicates were true
    * "smsp__sass_thread_inst_executed_op_ffma_pred_on.sum",  # of FFMA thread instructions executed where all predicates were true
    * "smsp__sass_thread_inst_executed_op_fmul_pred_on.sum",  # of FMUL thread instructions executed where all predicates were true

    :param metrics: The metrics to observe. This should be a list of strings.
        You can use ``ncu --query-metrics`` to get a list of valid metrics.
    :type metrics: list[str]
    """

    def __init__(self, metrics=None, device=0):
        if not nvmetrics:
            raise ImportError("could not import nvmetrics")

        self.metrics = metrics
        self.device = device
        self.results = dict()

    def before_start(self):
        nvmetrics.measureMetricsStart(self.metrics, self.device)

    def after_finish(self):
        self.results = nvmetrics.measureMetricsStop()

    def get_results(self):
        return dict(zip(self.metrics, self.results))
```

Review discussion on this file:

Member: Great! Here is an idea: would it be much work to add a way to get the list of supported metrics from

Author: A simple way to implement this is to let I have also thought about making this observer more user-friendly by providing presets of metrics. It is very tricky, though, as for many metrics there are mandatory suffixes like

Collaborator: Aren't there at least some metrics that are available on all GPUs that could be listed? I mean to get people started using the NCUObserver at first, and then if they want more they can still query `ncu --query-metrics`.

Author: I have added the functionality to query metrics to These different metric types should be passed with some suffix, these are called

Author: @stijnh, how would you suggest proceeding here?
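The docstring points at ``ncu --query-metrics`` for discovering valid metric names. A hypothetical helper (not part of this PR) could shell out to the Nsight Compute CLI and pull the metric names out of its tabular output. The exact output format varies between ncu versions, so the parsing below is a heuristic sketch, not a definitive implementation:

```python
import shutil
import subprocess


def parse_metric_names(output):
    """Extract metric names from `ncu --query-metrics`-style output.

    Assumes each metric line starts with the metric name; header and
    separator lines are filtered with the heuristic that NCU metric
    names contain a double underscore (e.g. dram__bytes).
    """
    names = []
    for line in output.splitlines():
        tokens = line.split()
        if tokens and "__" in tokens[0]:
            names.append(tokens[0])
    return names


def query_ncu_metrics():
    """Return metric names from the locally installed ncu, or [] if unavailable."""
    if shutil.which("ncu") is None:
        return []
    result = subprocess.run(["ncu", "--query-metrics"],
                            capture_output=True, text=True, check=True)
    return parse_metric_names(result.stdout)
```

For example, `parse_metric_names("Device TU102\ndram__bytes Counter byte")` yields `["dram__bytes"]`; the suffix variants (`.sum`, `.avg`, etc.) still have to be appended by the user, which is exactly the complication the discussion above is about.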
Further review discussion:

* I wonder whether we should change the name `NCUObserver` into something like `MetricsObserver`. Eventually we may want to also support measuring metrics on non-NVIDIA GPUs, and the observer could then pick the right backend based on whatever runner is used. Now that I mention runners, I never tested this observer with anything other than PyCUDA. The observer needs CUDA to be at least somewhat initialized. When testing nvmetrics with PyCUDA everything worked fine, but not when using CuPy, for instance. Maybe we still need to test this (and add the tests?), or at least mention it somewhere.

* In terms of naming, I think I would prefer to call a more generalized NCUObserver a "ProfilingObserver", because we already use the term "metrics" for the user-defined derived metrics. For now, I think we can keep the name NCUObserver until we have another profiling observer; when we have that, it would make sense to think about a more generic type. I will test with CuPy and cuda-python.

* Works with CuPy but not yet with cuda-python. This is the error I get with cuda-python:

* Not sure how this can happen. The initialization of the nvcuda backend is very similar to PyCUDA's, and by the time we call `before_start()` it should have been initialized for a while already.

* In the nvcuda backend, I tried to create a new context and destroy it atexit, like PyCUDA does, but that doesn't change anything.

* I think it's fixed in the nvmetrics version that I just pushed (revision c7c42130). It now calls `cuInit`, which doesn't seem to hurt even if PyCUDA already initialized CUDA before. To get a valid CUDA context, `cuDevicePrimaryCtxRetain` is used.

* Awesome! I just pulled and rebuilt the nvmetrics library and tested it with every CUDA backend. It all works now!

* That's great! I plan to make a 1.0 release of nvmetrics to go along with this MR.