Expected behavior
The lowering time of the given case should be around 10 seconds.
Actual behavior
The lowering time is more than 550 seconds.
Environment
Any environment with commit commit 101e3a4 (#13217 ) or later.
Steps to reproduce
The script:
import time
import tvm
from tvm import topi
class Timer :
def __init__ (self , msg ):
self .msg = msg
print (f"{ msg } ..." , flush = True )
def __enter__ (self ):
self .start = time .time ()
def __exit__ (self , * args ):
print (f"{ self .msg } ...{ time .time () - self .start :.2f} s" , flush = True )
def resize2d_dx_compute (inp , dy ):
"""compute definition for resize2d_dx op"""
size = (64 , 32 )
layout = "NCHW"
method = "cubic"
coord_trans = "half_pixel"
rounding_method = ""
cubic_alpha = - 0.75
cubic_exclude = 0
out_dtype = "float32"
out = topi .image .resize2d (
inp ,
(None , None , None , None ),
size ,
layout ,
method ,
coord_trans ,
rounding_method ,
bicubic_alpha = cubic_alpha ,
bicubic_exclude = cubic_exclude ,
out_dtype = out_dtype ,
)
grads = tvm .te .gradient (out , [inp ], head = dy )
return grads
inp = tvm .te .placeholder ((32 , 3 , 32 , 32 ), name = "inp" )
dy = tvm .te .placeholder ((32 , 3 , 64 , 32 ), name = "dy" )
with Timer ("te.gradient" ):
grads = resize2d_dx_compute (inp , dy )
# This problem is platform-independent.
with Timer ("schedule" ):
sch = topi .x86 .injective .schedule_injective (grads )
with Timer ("lower" ):
print (tvm .lower (sch , [inp , dy , grads [0 ]], simple_mode = True ))
Switch to a commit before 101e3a4 ([TIR][Transform] Optional data-flow analysis in RemoveNoOp #13217 ) and run the script.
Checkout the commit 101e3a4 ([TIR][Transform] Optional data-flow analysis in RemoveNoOp #13217 ) and run again.
Here are also the lowered IR without and with this commit:
Without this commit:
@main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32], []),
dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32], []),
resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32), float32, [32, 3, 32, 32], [])}
buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
for (ax0.ax1.fused: int32, 0, 96) "parallel" {
for (ax2: int32, 0, 32) {
for (ax3.outer: int32, 0, 2) {
resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304], [])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] = broadcast(0f32, 16)
for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
for (ax3.inner.s: int32, 0, 16) {
let cse_var_3: float32 = cast(float32, n1_n1_k3.shifted.shifted)
let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
let cse_var_1: float32 = (((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) {
let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)) + ax3.inner.s)
resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4] + (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) + (n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) - (2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32) + select((((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))))), 0f32))))
}
}
}
}
}
}
}
}
With this commit:
@main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32], []),
dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32], []),
resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32), float32, [32, 3, 32, 32], [])}
buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
for (ax0.ax1.fused: int32, 0, 96) "parallel" {
for (ax2: int32, 0, 32) {
for (ax3.outer: int32, 0, 2) {
resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304], [])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] = broadcast(0f32, 16)
for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
for (ax3.inner.s: int32, 0, 16) {
let cse_var_3: float32 = cast(float32, n1_n1_k3.shifted.shifted)
let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
let cse_var_1: float32 = (((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) {
let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)) + ax3.inner.s)
resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4] + (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) + (n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) - (2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32) + select((((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))))), 0f32))))
}
}
}
}
}
}
}
}
The IRs are pretty much identical, so it may be due to the change of lowering passes.
cc @Lunderberg @masahi
Triage
Expected behavior
The lowering time of the given case should be around 10 seconds.
Actual behavior
The lowering time is more than 550 seconds.
Environment
Any environment with commit commit 101e3a4 (#13217) or later.
Steps to reproduce
The script:
Here are also the lowered IR without and with this commit:
Without this commit:
With this commit:
The IRs are pretty much identical, so it may be due to the change of lowering passes.
cc @Lunderberg @masahi
Triage