Skip to content
82 changes: 82 additions & 0 deletions python/tvm/script/parser/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,88 @@ def inspect_class_capture(cls: type) -> dict[str, Any]:
return result


def _collect_annotation_names(source_obj: type | Callable) -> set[str]:
"""Parse source AST to find names used in function annotations.

Returns the set of ``ast.Name`` identifiers found inside argument
annotations and return annotations of any function definitions in
*source_obj*.
"""
import ast
import textwrap

try:
source = textwrap.dedent(inspect.getsource(source_obj))
tree = ast.parse(source)
except (OSError, TypeError):
return set()

names: set[str] = set()
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
for arg in node.args.args + node.args.posonlyargs + node.args.kwonlyargs:
if arg.annotation:
for n in ast.walk(arg.annotation):
if isinstance(n, ast.Name):
names.add(n.id)
if node.returns:
for n in ast.walk(node.returns):
if isinstance(n, ast.Name):
names.add(n.id)
return names


def _has_string_annotations(source_obj: type | Callable) -> bool:
"""Check if *source_obj* has stringified annotations (PEP 563)."""
if inspect.isclass(source_obj):
return any(
isinstance(a, str)
for v in source_obj.__dict__.values()
if inspect.isfunction(v)
for a in v.__annotations__.values()
)
return any(isinstance(a, str) for a in getattr(source_obj, "__annotations__", {}).values())


def _get_enclosing_scope_names(qualname: str) -> set[str]:
"""Extract lexically enclosing scope names from ``__qualname__``.

For ``outer.<locals>.inner.<locals>.func`` this returns ``{"outer", "inner"}``.
"""
parts = qualname.split(".")
return {p for p in parts[:-1] if p != "<locals>"}


def resolve_closure_vars(
source_obj: type | Callable, extra_vars: dict[str, Any], outer_stack: list
) -> None:
"""Resolve closure variables hidden by PEP 563.

With ``from __future__ import annotations``, variables used only in
annotations are not captured in ``__closure__``. This function parses
the source AST to find names used in function annotations, then looks
them up in lexically enclosing scope frames identified via
``__qualname__``.

Only triggered when annotations are actually strings (PEP 563 active).
Only annotation-referenced names are added, and only from enclosing
scopes — not from arbitrary caller frames.

Works for both classes (``@I.ir_module``) and functions (``@T.prim_func``).
"""
if not _has_string_annotations(source_obj):
return
ann_names = _collect_annotation_names(source_obj)
enclosing = _get_enclosing_scope_names(source_obj.__qualname__)
for name in ann_names:
if name not in extra_vars:
for frame_info in outer_stack[1:]:
if frame_info.frame.f_code.co_name in enclosing:
if name in frame_info.frame.f_locals:
extra_vars[name] = frame_info.frame.f_locals[name]
break


def is_defined_in_class(frames: list[FrameType], obj: Any) -> bool:
"""Check whether a object is defined in a class scope.

Expand Down
8 changes: 7 additions & 1 deletion python/tvm/script/parser/ir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,20 @@ def ir_module(mod: type | None = None, check_well_formed: bool = True) -> IRModu
The parsed ir module.
"""

# Capture stack outside wrapper (wrapper adds to the stack)
outer_stack = inspect.stack()

def decorator_wrapper(mod):
if not inspect.isclass(mod):
raise TypeError(f"Expect a class, but got: {mod}")

# Check BasePyModule inheritance
base_py_module_inherited = any(base.__name__ == "BasePyModule" for base in mod.__bases__)

m = parse(mod, utils.inspect_class_capture(mod), check_well_formed=check_well_formed)
extra_vars = utils.inspect_class_capture(mod)
# Resolve closure variables hidden by PEP 563 (annotation-only names)
utils.resolve_closure_vars(mod, extra_vars, outer_stack)
m = parse(mod, extra_vars, check_well_formed=check_well_formed)

if base_py_module_inherited:
# Lazy import: tvm.relax cannot be imported at module level in tvm.script.parser
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def decorator_wrapper(func):
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(outer_stack, func):
return func
f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed)
extra_vars = utils.inspect_function_capture(func)
utils.resolve_closure_vars(func, extra_vars, outer_stack)
f = parse(func, extra_vars, check_well_formed=check_well_formed)
setattr(f, "__name__", func.__name__)
return f

Expand Down
158 changes: 158 additions & 0 deletions tests/python/tvmscript/test_tvmscript_pep563_closure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test TVMScript with PEP 563 (from __future__ import annotations).

IMPORTANT: The `from __future__ import annotations` import below is the
test condition itself, because we need to test compatibility with it.
"""

from __future__ import annotations

import tvm
import tvm.testing
from tvm.script import ir as I
from tvm.script import tir as T


def _normalize(func):
"""Strip the global_symbol so function names do not affect structural equality."""
return func.with_attr("global_symbol", "")


def test_prim_func_closure_shape():
"""Closure variable used in Buffer shape annotation."""

def f(M=16):
@T.prim_func
def func(A: T.Buffer((M,), "float32")):
T.evaluate(0)

return func

@T.prim_func
def expected_16(A: T.Buffer((16,), "float32")):
T.evaluate(0)

@T.prim_func
def expected_32(A: T.Buffer((32,), "float32")):
T.evaluate(0)

tvm.ir.assert_structural_equal(_normalize(f(16)), _normalize(expected_16))
tvm.ir.assert_structural_equal(_normalize(f(32)), _normalize(expected_32))


def test_prim_func_closure_dtype():
"""Closure variable used as Buffer dtype."""

def f(dtype="float32"):
@T.prim_func
def func(A: T.Buffer((16,), dtype)):
T.evaluate(0)

return func

@T.prim_func
def expected_f32(A: T.Buffer((16,), "float32")):
T.evaluate(0)

@T.prim_func
def expected_f16(A: T.Buffer((16,), "float16")):
T.evaluate(0)

tvm.ir.assert_structural_equal(_normalize(f("float32")), _normalize(expected_f32))
tvm.ir.assert_structural_equal(_normalize(f("float16")), _normalize(expected_f16))


def test_prim_func_nested_closure():
"""Variables from enclosing scope active on the call stack (grandparent frame fallback).

With PEP 563, closure-only variables are missing from __closure__ unless they
appear in the function body. The ChainMap fallback walks the live call stack,
so this works when the enclosing frames are still active (outer calls middle
which applies the decorator, keeping outer's frame alive on the stack).
"""

def outer(M=16):
def middle(N=8):
@T.prim_func
def func(A: T.Buffer((M, N), "float32")):
T.evaluate(0)

return func

return middle()

@T.prim_func
def expected_16_8(A: T.Buffer((16, 8), "float32")):
T.evaluate(0)

@T.prim_func
def expected_32_8(A: T.Buffer((32, 8), "float32")):
T.evaluate(0)

tvm.ir.assert_structural_equal(_normalize(outer(16)), _normalize(expected_16_8))
tvm.ir.assert_structural_equal(_normalize(outer(32)), _normalize(expected_32_8))


def test_ir_module_closure():
"""Closure variable in @I.ir_module class method."""

def f(M=16):
@I.ir_module
class Mod:
@T.prim_func
def main(A: T.Buffer((M,), "float32")):
T.evaluate(0)

return Mod

@T.prim_func
def expected_16(A: T.Buffer((16,), "float32")):
T.evaluate(0)

@T.prim_func
def expected_32(A: T.Buffer((32,), "float32")):
T.evaluate(0)

tvm.ir.assert_structural_equal(_normalize(f(16)["main"]), _normalize(expected_16))
tvm.ir.assert_structural_equal(_normalize(f(32)["main"]), _normalize(expected_32))


def test_mixed_closure_usage():
"""Closure var used in both annotation AND body -- regression check."""

def f(M=16):
@T.prim_func
def func(A: T.Buffer((M,), "float32")):
T.evaluate(M)

return func

@T.prim_func
def expected_16(A: T.Buffer((16,), "float32")):
T.evaluate(16)

@T.prim_func
def expected_32(A: T.Buffer((32,), "float32")):
T.evaluate(32)

tvm.ir.assert_structural_equal(_normalize(f(16)), _normalize(expected_16))
tvm.ir.assert_structural_equal(_normalize(f(32)), _normalize(expected_32))


if __name__ == "__main__":
tvm.testing.main()
Loading