diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py index 85190b96d9ce..fc8a928e05d5 100644 --- a/python/tvm/script/parser/core/utils.py +++ b/python/tvm/script/parser/core/utils.py @@ -89,6 +89,88 @@ def inspect_class_capture(cls: type) -> dict[str, Any]: return result +def _collect_annotation_names(source_obj: type | Callable) -> set[str]: + """Parse source AST to find names used in function annotations. + + Returns the set of ``ast.Name`` identifiers found inside argument + annotations and return annotations of any function definitions in + *source_obj*. + """ + import ast + import textwrap + + try: + source = textwrap.dedent(inspect.getsource(source_obj)) + tree = ast.parse(source) + except (OSError, TypeError): + return set() + + names: set[str] = set() + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef): + for arg in node.args.args + node.args.posonlyargs + node.args.kwonlyargs: + if arg.annotation: + for n in ast.walk(arg.annotation): + if isinstance(n, ast.Name): + names.add(n.id) + if node.returns: + for n in ast.walk(node.returns): + if isinstance(n, ast.Name): + names.add(n.id) + return names + + +def _has_string_annotations(source_obj: type | Callable) -> bool: + """Check if *source_obj* has stringified annotations (PEP 563).""" + if inspect.isclass(source_obj): + return any( + isinstance(a, str) + for v in source_obj.__dict__.values() + if inspect.isfunction(v) + for a in v.__annotations__.values() + ) + return any(isinstance(a, str) for a in getattr(source_obj, "__annotations__", {}).values()) + + +def _get_enclosing_scope_names(qualname: str) -> set[str]: + """Extract lexically enclosing scope names from ``__qualname__``. + + For ``outer..inner..func`` this returns ``{"outer", "inner"}``. + """ + parts = qualname.split(".") + return {p for p in parts[:-1] if p != ""} + + +def resolve_closure_vars( + source_obj: type | Callable, extra_vars: dict[str, Any], outer_stack: list +) -> None: + """Resolve closure variables hidden by PEP 563. + + With ``from __future__ import annotations``, variables used only in + annotations are not captured in ``__closure__``. This function parses + the source AST to find names used in function annotations, then looks + them up in lexically enclosing scope frames identified via + ``__qualname__``. + + Only triggered when annotations are actually strings (PEP 563 active). + Only annotation-referenced names are added, and only from enclosing + scopes — not from arbitrary caller frames. + + Works for both classes (``@I.ir_module``) and functions (``@T.prim_func``). + """ + if not _has_string_annotations(source_obj): + return + ann_names = _collect_annotation_names(source_obj) + enclosing = _get_enclosing_scope_names(source_obj.__qualname__) + for name in ann_names: + if name not in extra_vars: + for frame_info in outer_stack[1:]: + if frame_info.frame.f_code.co_name in enclosing: + if name in frame_info.frame.f_locals: + extra_vars[name] = frame_info.frame.f_locals[name] + break + + def is_defined_in_class(frames: list[FrameType], obj: Any) -> bool: """Check whether a object is defined in a class scope. diff --git a/python/tvm/script/parser/ir/entry.py b/python/tvm/script/parser/ir/entry.py index 8f7a5be663cf..b0685e3db05f 100644 --- a/python/tvm/script/parser/ir/entry.py +++ b/python/tvm/script/parser/ir/entry.py @@ -46,6 +46,9 @@ def ir_module(mod: type | None = None, check_well_formed: bool = True) -> IRModu The parsed ir module. """ + # Capture stack outside wrapper (wrapper adds to the stack) + outer_stack = inspect.stack() + def decorator_wrapper(mod): if not inspect.isclass(mod): raise TypeError(f"Expect a class, but got: {mod}") @@ -53,7 +56,10 @@ def decorator_wrapper(mod): # Check BasePyModule inheritance base_py_module_inherited = any(base.__name__ == "BasePyModule" for base in mod.__bases__) - m = parse(mod, utils.inspect_class_capture(mod), check_well_formed=check_well_formed) + extra_vars = utils.inspect_class_capture(mod) + # Resolve closure variables hidden by PEP 563 (annotation-only names) + utils.resolve_closure_vars(mod, extra_vars, outer_stack) + m = parse(mod, extra_vars, check_well_formed=check_well_formed) if base_py_module_inherited: # Lazy import: tvm.relax cannot be imported at module level in tvm.script.parser diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index da09851e6757..d0486b0d9fbc 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -63,7 +63,9 @@ def decorator_wrapper(func): raise TypeError(f"Expect a function, but got: {func}") if utils.is_defined_in_class(outer_stack, func): return func - f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed) + extra_vars = utils.inspect_function_capture(func) + utils.resolve_closure_vars(func, extra_vars, outer_stack) + f = parse(func, extra_vars, check_well_formed=check_well_formed) setattr(f, "__name__", func.__name__) return f diff --git a/tests/python/tvmscript/test_tvmscript_pep563_closure.py b/tests/python/tvmscript/test_tvmscript_pep563_closure.py new file mode 100644 index 000000000000..a5d26d7f1628 --- /dev/null +++ b/tests/python/tvmscript/test_tvmscript_pep563_closure.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test TVMScript with PEP 563 (from __future__ import annotations). + +IMPORTANT: The `from __future__ import annotations` import below is the +test condition itself, because we need to test compatibility with it. +""" + +from __future__ import annotations + +import tvm +import tvm.testing +from tvm.script import ir as I +from tvm.script import tir as T + + +def _normalize(func): + """Strip the global_symbol so function names do not affect structural equality.""" + return func.with_attr("global_symbol", "") + + +def test_prim_func_closure_shape(): + """Closure variable used in Buffer shape annotation.""" + + def f(M=16): + @T.prim_func + def func(A: T.Buffer((M,), "float32")): + T.evaluate(0) + + return func + + @T.prim_func + def expected_16(A: T.Buffer((16,), "float32")): + T.evaluate(0) + + @T.prim_func + def expected_32(A: T.Buffer((32,), "float32")): + T.evaluate(0) + + tvm.ir.assert_structural_equal(_normalize(f(16)), _normalize(expected_16)) + tvm.ir.assert_structural_equal(_normalize(f(32)), _normalize(expected_32)) + + +def test_prim_func_closure_dtype(): + """Closure variable used as Buffer dtype.""" + + def f(dtype="float32"): + @T.prim_func + def func(A: T.Buffer((16,), dtype)): + T.evaluate(0) + + return func + + @T.prim_func + def expected_f32(A: T.Buffer((16,), "float32")): + T.evaluate(0) + + @T.prim_func + def expected_f16(A: T.Buffer((16,), "float16")): + T.evaluate(0) + + tvm.ir.assert_structural_equal(_normalize(f("float32")), _normalize(expected_f32)) + tvm.ir.assert_structural_equal(_normalize(f("float16")), _normalize(expected_f16)) + + +def test_prim_func_nested_closure(): + """Variables from enclosing scope active on the call stack (grandparent frame fallback). + + With PEP 563, closure-only variables are missing from __closure__ unless they + appear in the function body. The ChainMap fallback walks the live call stack, + so this works when the enclosing frames are still active (outer calls middle + which applies the decorator, keeping outer's frame alive on the stack). + """ + + def outer(M=16): + def middle(N=8): + @T.prim_func + def func(A: T.Buffer((M, N), "float32")): + T.evaluate(0) + + return func + + return middle() + + @T.prim_func + def expected_16_8(A: T.Buffer((16, 8), "float32")): + T.evaluate(0) + + @T.prim_func + def expected_32_8(A: T.Buffer((32, 8), "float32")): + T.evaluate(0) + + tvm.ir.assert_structural_equal(_normalize(outer(16)), _normalize(expected_16_8)) + tvm.ir.assert_structural_equal(_normalize(outer(32)), _normalize(expected_32_8)) + + +def test_ir_module_closure(): + """Closure variable in @I.ir_module class method.""" + + def f(M=16): + @I.ir_module + class Mod: + @T.prim_func + def main(A: T.Buffer((M,), "float32")): + T.evaluate(0) + + return Mod + + @T.prim_func + def expected_16(A: T.Buffer((16,), "float32")): + T.evaluate(0) + + @T.prim_func + def expected_32(A: T.Buffer((32,), "float32")): + T.evaluate(0) + + tvm.ir.assert_structural_equal(_normalize(f(16)["main"]), _normalize(expected_16)) + tvm.ir.assert_structural_equal(_normalize(f(32)["main"]), _normalize(expected_32)) + + +def test_mixed_closure_usage(): + """Closure var used in both annotation AND body -- regression check.""" + + def f(M=16): + @T.prim_func + def func(A: T.Buffer((M,), "float32")): + T.evaluate(M) + + return func + + @T.prim_func + def expected_16(A: T.Buffer((16,), "float32")): + T.evaluate(16) + + @T.prim_func + def expected_32(A: T.Buffer((32,), "float32")): + T.evaluate(32) + + tvm.ir.assert_structural_equal(_normalize(f(16)), _normalize(expected_16)) + tvm.ir.assert_structural_equal(_normalize(f(32)), _normalize(expected_32)) + + +if __name__ == "__main__": + tvm.testing.main()