[CUTLASS] Add parallel split-k support to wgrad #10185
|
If you want to investigate the accuracy issue, I suggest you compare both cutlass and cuDNN against a naive fp64 or fp32 version. |
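A naive reference of the kind suggested above could look like the following sketch. It is not code from this PR: the function names, the NHWC activation layout, the KRSC weight-gradient layout, and the restriction to no dilation and no groups are all assumptions made for illustration. Python floats are fp64, so the accumulation here is double precision.

```python
def naive_wgrad_nhwc(x, dy, r_size, s_size, stride=1, pad=0):
    """Direct-loop conv2d weight gradient in fp64 (hypothetical reference).

    x:  nested lists shaped [N][H][W][C]  (input activations)
    dy: nested lists shaped [N][P][Q][K]  (gradient w.r.t. the conv output)
    Returns dw shaped [K][R][S][C].
    """
    n_size, h_size, w_size, c_size = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    p_size, q_size, k_size = len(dy[0]), len(dy[0][0]), len(dy[0][0][0])
    dw = [[[[0.0] * c_size for _ in range(s_size)] for _ in range(r_size)]
          for _ in range(k_size)]
    # The reduction runs over n, p, q: this is the implicit GEMM K dimension.
    for n in range(n_size):
        for p in range(p_size):
            for q in range(q_size):
                for r in range(r_size):
                    h = p * stride + r - pad
                    if not 0 <= h < h_size:
                        continue  # tap falls into the zero padding
                    for s in range(s_size):
                        w = q * stride + s - pad
                        if not 0 <= w < w_size:
                            continue
                        for k in range(k_size):
                            grad = dy[n][p][q][k]
                            for c in range(c_size):
                                dw[k][r][s][c] += grad * x[n][h][w][c]
    return dw


def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two nested lists."""
    if isinstance(a, list):
        return max(max_abs_diff(u, v) for u, v in zip(a, b))
    return abs(a - b)


# Tiny smoke test: all-ones input and all-ones output gradient.
x = [[[[1.0] for _ in range(3)] for _ in range(3)]]   # N=1, H=W=3, C=1
dy = [[[[1.0] for _ in range(2)] for _ in range(2)]]  # N=1, P=Q=2, K=1
reference = naive_wgrad_nhwc(x, dy, r_size=2, s_size=2)
print(reference)  # every filter tap accumulates P * Q = 4 products of 1.0
```

With such a reference, `max_abs_diff(reference, cutlass_out)` and `max_abs_diff(reference, cudnn_out)` (after converting both library outputs to nested lists) would show which of the two is further from the fp64 result.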
commit 60b73a91b79d644d8c95f682eedaf47a89abba0d
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Tue Feb 8 10:43:11 2022 +0900
pylint
commit ae2e718
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Sun Feb 6 14:51:52 2022 +0900
Add split-k support for wgrad
commit 43820d5
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Sun Feb 6 10:07:34 2022 +0900
fix and add doc
commit 446a95b
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Sun Feb 6 09:48:38 2022 +0900
dw conv2d properly supported for wgrad
commit adc4e22
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Sat Feb 5 16:32:42 2022 +0900
fix overwriting template
commit 040eab0
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Sat Feb 5 16:06:27 2022 +0900
black
commit e5a07c2
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Sat Feb 5 16:03:10 2022 +0900
add reduction in profiler
commit be89334
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Sat Feb 5 06:58:03 2022 +0900
adding split k reduction to conv2d profiler
commit ae09b0f
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Fri Feb 4 11:52:59 2022 +0900
fixed conv2d_backward_weight typerel for dw conv2d
commit 16fe531
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Thu Feb 3 12:59:22 2022 +0900
wip
commit 2167c25
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Thu Feb 3 04:22:19 2022 +0900
fix conv2d type rel for depth wise and grouped conv2d
commit 14b12e5
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Fri Feb 4 05:01:03 2022 +0900
remove split_k.py
commit b141271
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Fri Feb 4 04:48:21 2022 +0900
workaround for invalid split_k_slice
commit 6e4c7e1
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Fri Feb 4 02:43:58 2022 +0900
support split k in profiler
commit 2eb1cf4
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Fri Feb 4 02:31:03 2022 +0900
improvement
commit 0bce8f3
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Thu Feb 3 18:20:12 2022 +0900
fixed for fp16 output
commit 30df1bd
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Thu Feb 3 17:50:33 2022 +0900
fp32 output works
commit 7a51995
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Thu Feb 3 14:30:22 2022 +0900
fix
commit 4a383e2
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Thu Feb 3 14:05:24 2022 +0900
update c++ codegen
commit 6206e38
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Thu Feb 3 13:46:05 2022 +0900
wip
commit 0ece49b
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Thu Feb 3 03:05:21 2022 +0900
wip
commit 08a6147
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Wed Feb 2 13:10:21 2022 +0900
test worked with fp32 output
commit 084d5c4
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Wed Feb 2 12:35:18 2022 +0900
fix compile error for fprop
commit 31f2543
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Wed Feb 2 12:18:06 2022 +0900
compiled
commit c2098e7
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Wed Feb 2 11:11:43 2022 +0900
wip
commit a145850
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Sun Feb 6 14:46:16 2022 +0900
fixed for sm75
commit 6151506
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Sun Feb 6 14:32:46 2022 +0900
all tests work
commit 041c094
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Sun Feb 6 14:19:09 2022 +0900
dw conv2d properly supported for wgrad
commit 2191918
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Wed Feb 2 09:14:05 2022 +0900
wgrad tests now work under pytest
commit 78f76df
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Wed Feb 2 07:31:54 2022 +0900
run black
commit 0a82149
Author: Masahiro Masuda <masahi129@gmail.com>
Date: Wed Feb 2 06:12:39 2022 +0900
[CUTLASS] Add wgrad support (without split-k)
Force-pushed from 5ec73cb to d11f9bc.
|
cc @mbs-octoml interesting example for perf work |
|
Hi Masa, This is amazing progress. Some questions on the known issues:
|
|
Hi Manish,
The benchmark result I linked above shows the accuracy difference in the last two columns. Most workloads have some differences, except for some deeper layers at batch = 8, which showed an exact match. It seems deeper layers, those with small spatial size and large channel counts, generally have fewer accuracy problems. The differences become much bigger for batch = 256. So it kind of works but not quite, and it is very hard to debug. The profiler in cutlass doesn't report any accuracy problem, which is another mystery. It could be that TVM's use of cuDNN wgrad has some issues.
The issue is memory reuse across multiple calls. The ways we integrate cuDNN and cutlass are significantly different. I tried to apply a memory management strategy similar to the one we use for cuDNN to the JIT-generated cutlass code, but as I said above, I'm having strange issues.
Yes, I haven't grokked your note in that thread. I just tried a dumb strategy in my benchmark and it already shows good performance. I didn't pursue perf improvement further, since the accuracy problem was more concerning. |
|
On accuracy, floating-point addition is not associative. Changing the order can change the result. Parallel reduction does change the order of accumulation over GEMM-K (NPQ). Thus, some change between runs is expected. I don't have guidance on what threshold to set when checking relative error. I would take Haicheng's suggestions here and follow:
The CUTLASS profiler uses integer inputs to initialize tensors and matrices. This makes error checking easier. You can also use the CUTLASS profiler approach to make sure there are no functional errors, i.e., try the operation on integer inputs. |
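Both points above can be illustrated with a small sketch. The helper names are hypothetical, and Python's fp64 floats stand in for the fp16/fp32 accumulators used by the real kernels, so the magnitudes are not representative; only the qualitative behaviour is.

```python
import random


def sequential_sum(values):
    """Accumulate left to right, like a single pass over GEMM-K."""
    acc = 0.0
    for v in values:
        acc += v
    return acc


def split_k_sum(values, slices):
    """Sum contiguous slices independently, then reduce the partial sums."""
    chunk = -(-len(values) // slices)  # ceiling division
    partials = [sequential_sum(values[i:i + chunk])
                for i in range(0, len(values), chunk)]
    return sequential_sum(partials)


rng = random.Random(0)
real_inputs = [rng.uniform(-1.0, 1.0) for _ in range(10000)]
int_inputs = [float(rng.randint(-3, 3)) for _ in range(10000)]

# Reassociating a floating-point sum can change the result.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False

real_gap = abs(sequential_sum(real_inputs) - split_k_sum(real_inputs, 16))
int_gap = abs(sequential_sum(int_inputs) - split_k_sum(int_inputs, 16))
print("real-valued inputs, |difference|:", real_gap)
# Small integers and their sums are exactly representable, so the
# accumulation order cannot matter and the difference is exactly 0.0.
print("integer-valued inputs, |difference|:", int_gap)
```

With integer-valued inputs, any mismatch between two implementations points to a functional bug rather than rounding, which is why the integer initialization is useful for debugging.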
|
Actually, the accuracy difference was there even before I added parallel split-k to wgrad, and the result got closer to cuDNN after adding split-k. So I believe the issue is not in the parallel reduction; something is off elsewhere. I have seen some workloads where cuDNN uses cutlass's wgrad and reduction kernels, and even in that case there was a difference. Probably I should look at how TVM is using cuDNN wgrad first. I haven't applied fp32 wgrad to large inputs; for the small ones we used in the unit test, the result looked good. We also have the option of comparing against TVM native results, which I have only looked at briefly.
That's very interesting... I didn't know that. I can definitely try, thanks. |
* [CUTLASS] Add split-k support to wgrad
* pylint
* add more doc
* more doc clarification
Building on #10177, this adds parallel split-k support to wgrad.
@comaniac @Laurawly @junrushao1994 @vinx13 @YuchenJin @hwu36 @manishucsd
Split-k is described in https://github.com/NVIDIA/cutlass/blob/master/media/docs/efficient_gemm.md#parallelized-reductions.
This is my first experience using split-k in cutlass or any other API. Wgrad is particularly interesting for split-k since the implicit gemm K dimension is really large in wgrad (`N * P * Q`, where `P` and `Q` are the output H and W). Without split-k, wgrad on large spatial inputs is extremely slow.
For now, I'm not trying anything smart to pick the split-k parameter; instead, we ask users to provide possible candidates. I tuned over `[1, 4, 8, 16, 32, 64]` below and that already showed excellent performance. The benchmark code is here.
Benchmark result against cuDNN. Note that currently there are non-trivial differences between the cuDNN and TVM + cutlass outputs, especially for the larger batch size. I didn't find anything obviously wrong in the generated code and I gave up fixing the accuracy difference at some point. Also note that the difference is not due to parallel split-k: even in the normal case the results were different (and actually improved after split-k lol).
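To make the size of the wgrad reduction dimension concrete, here is a small sketch that computes the implicit GEMM shape for one layer and how much of GEMM-K each slice would cover for the candidates listed above. The function names and the example layer are made up for illustration, the M/N mapping follows the usual CUTLASS implicit GEMM convention for wgrad, and real kernels round the per-slice extent to tile boundaries, which this ignores.

```python
def conv_output_size(in_size, kernel, stride=1, pad=0, dilation=1):
    """Standard convolution output extent along one spatial axis."""
    return (in_size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


def wgrad_gemm_shape(n, h, w, c, k, r, s, stride=1, pad=0):
    """Implicit GEMM problem size for wgrad: M = K, N = R*S*C, K = N*P*Q."""
    p = conv_output_size(h, r, stride, pad)
    q = conv_output_size(w, s, stride, pad)
    return k, r * s * c, n * p * q


# Hypothetical layer: batch 8, 56x56 input, 64 -> 64 channels, 3x3 filter.
gemm_m, gemm_n, gemm_k = wgrad_gemm_shape(
    n=8, h=56, w=56, c=64, k=64, r=3, s=3, stride=1, pad=1)
print("GEMM M, N, K:", gemm_m, gemm_n, gemm_k)  # 64 576 25088

for slices in [1, 4, 8, 16, 32, 64]:
    k_per_slice = -(-gemm_k // slices)  # ceiling division
    print(f"split_k_slices={slices:2d} -> about {k_per_slice} K-iterations per slice")
```

The output tile (M x N) is small compared to the reduction length, so without splitting K only a few thread blocks would do all the work, which matches the slowdown described above.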
The result showed cutlass winning across the board (`Profiler time` vs `cuDNN` columns, but again, the results do not match exactly). However, there is a serious problem when cutlass wgrad + split-k kernels are called from TVM (`TVM + CUTLASS` column): split-k requires a large workspace, and the space requirement grows linearly with the `split-k-slices` parameter. Right now we naively allocate the workspace on every cutlass kernel call on each run, while for cuDNN we have a simple workspace memory reuse mechanism (together with a thread local storage) implemented in tvm/src/runtime/contrib/cudnn/cudnn_utils.cc, lines 153 to 161 in 211291f.
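The two ideas in the paragraph above can be sketched as follows. This is not the code in `cudnn_utils.cc` and not what the generated cutlass code does: `split_k_workspace_bytes` assumes one full partial output per slice (a simplified model that ignores any special-casing), and `WorkspaceCache` is a hypothetical grow-only, per-thread cache in which a `bytearray` stands in for a device allocation.

```python
import threading


def split_k_workspace_bytes(gemm_m, gemm_n, split_k_slices, bytes_per_element=4):
    """Assumed model: each slice writes its own M x N partial result."""
    return split_k_slices * gemm_m * gemm_n * bytes_per_element


class WorkspaceCache:
    """Grow-only scratch buffer, kept separately for each thread."""

    def __init__(self):
        self._local = threading.local()

    def get(self, nbytes):
        buf = getattr(self._local, "buf", None)
        if buf is None or len(buf) < nbytes:
            # Only reallocate when the cached buffer is too small.
            buf = bytearray(nbytes)  # stand-in for a device allocation
            self._local.buf = buf
            self._local.allocations = self.allocations + 1
        return buf

    @property
    def allocations(self):
        """Number of allocations made by the calling thread."""
        return getattr(self._local, "allocations", 0)


cache = WorkspaceCache()
for slices in [4, 16, 8, 16]:
    nbytes = split_k_workspace_bytes(gemm_m=64, gemm_n=576, split_k_slices=slices)
    workspace = cache.get(nbytes)
    print(f"slices={slices:2d} needs {nbytes} bytes, cached buffer has {len(workspace)}")

print("allocations:", cache.allocations)  # 2: only the first and the largest request allocate
```

Compared to allocating on every kernel call, the cache pays for an allocation only when a request exceeds everything seen so far on that thread, at the cost of holding on to the largest buffer.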
I attempted adding simple workspace memory management in https://github.com/masahi/tvm/compare/cutlass-split-k...masahi:cutlass-workspace?expand=1. It kind of works in terms of the expected perf improvement. However, I'm getting segfaults and other strange issues. I'm a bit confused as to what the right behavior should be for a thread-local memory manager in the context of multiple JIT-generated and compiled translation units. Let me know if you have any thoughts on this issue.
Known issues and TODO