[Relay][Training] Add more missing gradients #6767
Please try to pass lint locally first~
@register_gradient("take")
def take_grad(orig, grad):
You can get by with a 'put' operator that puts a scalar into an index of a tensor and leaves the other places unchanged. put and take have some classic properties, which I assume will be better for the optimizer. It also allows other optimizations. For example, with put and reduce_sum, grad + (put vala at idxa in 0_array) + (put valb at idxb in 0_array) will be collapsed into a long chain of puts on grad, allowing copy-on-write (COW) to kick in so that every take grad update happens by mutation instead of creating another tensor.
This is a good point that I was wondering about. The loop is basically just implementing a put operation (like I described in the comment), so it would make sense to have it be a separate op since I imagine it will be useful in general. Should I remove this gradient for now, or keep it and replace it with put once I implement it?
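A minimal sketch of the idea in this thread, written in plain Python on 1-D lists rather than as a Relay op. The helper names `put`, `add`, and `take_grad_via_put` are hypothetical and are not part of this PR or of TVM. It assumes a 1-D `take` along axis 0 with in-range indices, and follows the formulation above: each output gradient is put into a zero array and the results are summed, so repeated indices accumulate.

```python
def put(base, index, value):
    """Return a copy of `base` with `value` written at `index`.

    Hypothetical helper: all other positions are left unchanged.
    """
    out = list(base)
    out[index] = value
    return out


def add(a, b):
    """Elementwise sum of two equal-length lists."""
    return [x + y for x, y in zip(a, b)]


def take_grad_via_put(data_len, indices, out_grad):
    """Gradient of a 1-D take w.r.t. its data input (assumed semantics).

    Computes grad + put(zeros, idx_a, val_a) + put(zeros, idx_b, val_b) + ...
    so that duplicate indices add up instead of overwriting each other.
    """
    zeros = [0.0] * data_len
    grad = list(zeros)
    for idx, g in zip(indices, out_grad):
        grad = add(grad, put(zeros, idx, g))
    return grad


# Index 1 is taken twice, so its two output gradients are summed.
result = take_grad_via_put(4, [1, 3, 1], [1.0, 2.0, 5.0])
```

An optimizer that understands put could, in principle, fuse each add of a one-hot put into an in-place update on `grad`. This sketch does not attempt that and copies the list at every step.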
Merging per reviews. Feel free to send follow-up improvements per @MarisaKirisame's comment.
Added the following gradients:
- take
- reverse_reshape
- stack
- squeeze
- expand_dims
- arange

Also fixed a typo in type solver diagnostics. I had to use a Relay loop in take_grad to support Any size in indices.

cc @t-vi @SWu @MarisaKirisame