diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 41657dd73..de37695e7 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -395,5 +395,31 @@ def prim_call_macro(): pass +def frame_inside_macro(): + + @tilelang.jit + def get_sample_kernel(): + + @T.macro + def transform(x): + return x + 1 + + @T.prim_func + def sample_kernel( + num_blocks: T.int32, + idx_out: T.Tensor[(32,), T.int32], + ): + with T.Kernel(num_blocks, threads=32) as block_idx: # noqa: F841 + fragment = T.alloc_fragment(32, 'int32') + T.copy(idx_out, fragment) + + for i in T.Parallel(32): + idx_out[i] = transform(fragment[i]) + + return sample_kernel + + kernel = get_sample_kernel() # noqa: F841 + + if __name__ == '__main__': tilelang.testing.main() diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index e693f8504..1004a54ac 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -80,6 +80,10 @@ class MacroFrame(Frame): ... +class ExitedMacroFrame(Frame): + ... + + class BoolOpFrame(Frame): ... @@ -164,8 +168,22 @@ def macro(self, name=None, annotations=None): save = self.name_inside_frame, self.arg_annotations self.name_inside_frame = {} self.arg_annotations = annotations or {} - with self.with_frame(MacroFrame()): - yield + pos = len(self.frames) + # here we add a ExitedMacroFrame to preserve the frame stack inside macro + # because macro may bind some variable, and return it + # + # ```py + # @T.macro + # def foo(x): + # y = x + 1 + # return y + # @T.prim_func + # def bar(): + # c = foo(1) # macro generates let y = x + 1 + # d = c # d = c should lay inside frame of `let y = x + 1` + self.frames.append(MacroFrame()) + yield + self.frames[pos] = ExitedMacroFrame() self.name_inside_frame, self.arg_annotations = save def get(self):