From 0ae9cd3a594621bf3c4c1043a6a1644bfe81924f Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 5 Aug 2025 13:20:32 -0400 Subject: [PATCH] [REFACTOR] Phase out getattr based attribute handling This PR phases out getattar based attribute handling as they are slower and introduces extra code path. This does mean that if an Object is not explicitly registered in python side, we will no longer be able to access the field by name. Likely this is also desirable as we would like to enable faster use that updates the python end and do not rely on these behavior. --- docs/reference/api/python/relax/op.rst | 1 + docs/reference/api/python/tir/transform.rst | 1 + ffi/src/ffi/extra/serialization.cc | 13 +- include/tvm/relax/attrs/op.h | 30 ++-- include/tvm/script/printer/doc.h | 5 +- python/tvm/arith/iter_affine_map.py | 6 + python/tvm/contrib/msc/core/ir/graph.py | 4 +- python/tvm/ffi/__init__.py | 1 + python/tvm/ffi/cython/function.pxi | 2 + python/tvm/ffi/cython/object.pxi | 72 ++++---- python/tvm/ffi/serialization.py | 67 ++++++++ python/tvm/ir/attrs.py | 14 +- python/tvm/relax/dpl/pattern.py | 2 +- python/tvm/relax/expr.py | 5 + python/tvm/relax/op/_op_gradient.py | 4 +- python/tvm/relax/op/manipulate.py | 1 + python/tvm/relax/op/op_attrs.py | 155 ++++++++++++++++++ python/tvm/runtime/_ffi_node_api.py | 8 - python/tvm/runtime/object.py | 11 -- python/tvm/script/printer/doc.py | 35 +--- python/tvm/te/tensor.py | 20 --- python/tvm/testing/__init__.py | 1 + python/tvm/testing/attrs.py | 28 ++++ python/tvm/tir/transform/transform.py | 42 +++++ src/node/reflection.cc | 74 +-------- src/relax/ir/emit_te.h | 2 +- src/tir/transforms/hoist_expression.cc | 4 +- tests/python/ffi/test_container.py | 7 + tests/python/ir/test_ir_attrs.py | 4 +- .../test_transform_legalize_ops_manipulate.py | 7 +- tests/python/runtime/test_runtime_rpc.py | 1 - 31 files changed, 406 insertions(+), 221 deletions(-) create mode 100644 python/tvm/ffi/serialization.py create mode 100644 python/tvm/testing/attrs.py diff --git a/docs/reference/api/python/relax/op.rst b/docs/reference/api/python/relax/op.rst index 21f638442a84..922af768f50f 100644 --- a/docs/reference/api/python/relax/op.rst +++ b/docs/reference/api/python/relax/op.rst @@ -70,3 +70,4 @@ tvm.relax.op.op_attrs ********************* .. automodule:: tvm.relax.op.op_attrs :members: + :exclude-members: Attrs diff --git a/docs/reference/api/python/tir/transform.rst b/docs/reference/api/python/tir/transform.rst index 8ce641b6d3f6..29f1bcbbf036 100644 --- a/docs/reference/api/python/tir/transform.rst +++ b/docs/reference/api/python/tir/transform.rst @@ -20,4 +20,5 @@ tvm.tir.transform ----------------- .. automodule:: tvm.tir.transform :members: + :exclude-members: Attrs :imported-members: diff --git a/ffi/src/ffi/extra/serialization.cc b/ffi/src/ffi/extra/serialization.cc index 8d9df03361c2..ea9a96b696ec 100644 --- a/ffi/src/ffi/extra/serialization.cc +++ b/ffi/src/ffi/extra/serialization.cc @@ -408,9 +408,20 @@ class ObjectGraphDeserializer { Any FromJSONGraph(const json::Value& value) { return ObjectGraphDeserializer::Deserialize(value); } +// string version of the api +Any FromJSONGraphString(const String& value) { return FromJSONGraph(json::Parse(value)); } + +String ToJSONGraphString(const Any& value, const Any& metadata) { + return json::Stringify(ToJSONGraph(value, metadata)); +} + TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.ToJSONGraph", ToJSONGraph).def("ffi.FromJSONGraph", FromJSONGraph); + refl::GlobalDef() + .def("ffi.ToJSONGraph", ToJSONGraph) + .def("ffi.ToJSONGraphString", ToJSONGraphString) + .def("ffi.FromJSONGraph", FromJSONGraph) + .def("ffi.FromJSONGraphString", FromJSONGraphString); refl::EnsureTypeAttrColumn("__data_to_json__"); refl::EnsureTypeAttrColumn("__data_from_json__"); }); diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index cce78e9fd615..337f8dc4cbc2 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -51,16 +51,19 @@ struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter /*! \brief Attributes used in call_tir_inplace */ struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter { + /*! + * \brief Indices that describe which input corresponds to which output. + * + * If the `i`th member has the value `k` >= 0, then that means that input `k` should be used to + * store the `i`th output. If an element has the value -1, that means a new tensor should be + * allocated for that output. + */ Array inplace_indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro( - "inplace_indices", &CallTIRInplaceAttrs::inplace_indices, - "Indices that describe which input corresponds to which output. If the `i`th member " - "has the value `k` >= 0, then that means that input `k` should be used to store the " - "`i`th output. If an element has the value -1, that means a new tensor should be " - "allocated for that output."); + refl::ObjectDef().def_ro("inplace_indices", + &CallTIRInplaceAttrs::inplace_indices); } static constexpr const char* _type_key = "relax.attrs.CallTIRInplaceAttrs"; @@ -69,16 +72,19 @@ struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in call_inplace_packed */ struct CallInplacePackedAttrs : public AttrsNodeReflAdapter { + /*! + * \brief Indices that describe which input corresponds to which output. + * + * If the `i`th member has the value `k` >= 0, then that means that input `k` should be used to + * store the `i`th output. If an element has the value -1, that means the output will be newly + * allocated. + */ Array inplace_indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro( - "inplace_indices", &CallInplacePackedAttrs::inplace_indices, - "Indices that describe which input corresponds to which output. If the `i`th member " - "has the value `k` >= 0, then that means that input `k` should be used to store the " - "`i`th output. If an element has the value -1, that means the output will be newly " - "allocated."); + refl::ObjectDef().def_ro("inplace_indices", + &CallInplacePackedAttrs::inplace_indices); } static constexpr const char* _type_key = "relax.attrs.CallInplacePackedAttrs"; diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index b19bcab4c3ef..de3fb0bbad2c 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -65,10 +65,11 @@ class DocNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("source_paths", &DocNode::source_paths); + refl::ObjectDef().def_rw("source_paths", &DocNode::source_paths); } static constexpr const char* _type_key = "script.printer.Doc"; + static constexpr bool _type_mutable = true; TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object); @@ -174,7 +175,7 @@ class StmtDocNode : public DocNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("comment", &StmtDocNode::comment); + refl::ObjectDef().def_rw("comment", &StmtDocNode::comment); } static constexpr const char* _type_key = "script.printer.StmtDoc"; diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index dbb4087f325f..328bb052b87f 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -22,6 +22,7 @@ from . import _ffi_api +@tvm.ffi.register_object("arith.IterMapExpr") class IterMapExpr(PrimExpr): """Base class of all IterMap expressions.""" @@ -89,6 +90,11 @@ def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) +@tvm.ffi.register_object("arith.IterMapResult") +class IterMapResult(Object): + """Result of iter map detection.""" + + class IterMapLevel(IntEnum): """Possible kinds of iter mapping check level.""" diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index 9aa5bde93380..7bd88df5f6f4 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -194,6 +194,7 @@ def ndim(self) -> int: return len(self.shape) +@tvm.ffi.register_object("msc.core.BaseJoint") class BaseJoint(Object): """Base class of all MSC Nodes.""" @@ -561,6 +562,7 @@ def has_attr(self, key: str) -> bool: return bool(_ffi_api.WeightJointHasAttr(self, key)) +@tvm.ffi.register_object("msc.core.BaseGraph") class BaseGraph(Object): """Base class of all MSC Graphs.""" @@ -955,7 +957,7 @@ def visualize(self, path: Optional[str] = None) -> str: @tvm.ffi.register_object("msc.core.WeightGraph") -class WeightGraph(Object): +class WeightGraph(BaseGraph): """The WeightGraph Parameters diff --git a/python/tvm/ffi/__init__.py b/python/tvm/ffi/__init__.py index b507064e34d9..43a20e751c29 100644 --- a/python/tvm/ffi/__init__.py +++ b/python/tvm/ffi/__init__.py @@ -30,6 +30,7 @@ from .ndarray import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, hexagon, webgpu from .ndarray import from_dlpack, NDArray, Shape from .container import Array, Map +from . import serialization from . import testing diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index cbff3fecf135..8c9df19642b0 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -426,3 +426,5 @@ def _convert_to_ffi_func(object pyfunc): _STR_CONSTRUCTOR = _get_global_func("ffi.String", False) _BYTES_CONSTRUCTOR = _get_global_func("ffi.Bytes", False) +_OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True) +_OBJECT_TO_JSON_GRAPH_STR = _get_global_func("ffi.ToJSONGraphString", True) diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi index 4efedf35d8f4..7df5f7a19aff 100644 --- a/python/tvm/ffi/cython/object.pxi +++ b/python/tvm/ffi/cython/object.pxi @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import warnings _CLASS_OBJECT = None _FUNC_CONVERT_TO_OBJECT = None + def _set_class_object(cls): global _CLASS_OBJECT _CLASS_OBJECT = cls @@ -32,31 +34,15 @@ def __object_repr__(obj): return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" -def __object_save_json__(obj): - """Object repr function that can be overridden by assigning to it""" - raise NotImplementedError("JSON serialization depends on downstream init") - - -def __object_load_json__(json_str): - """Object repr function that can be overridden by assigning to it""" - raise NotImplementedError("JSON serialization depends on downstream init") - - -def __object_dir__(obj): - """Object dir function that can be overridden by assigning to it""" - return [] - - -def __object_getattr__(obj, name): - """Object getattr function that can be overridden by assigning to it""" - raise AttributeError() - - def _new_object(cls): """Helper function for pickle""" return cls.__new__(cls) +_OBJECT_FROM_JSON_GRAPH_STR = None +_OBJECT_TO_JSON_GRAPH_STR = None + + class ObjectGeneric: """Base class for all classes that can be converted to object.""" @@ -107,34 +93,24 @@ cdef class Object: return (_new_object, (cls,), self.__getstate__()) def __getstate__(self): + if _OBJECT_TO_JSON_GRAPH_STR is None: + raise RuntimeError("ffi.ToJSONGraphString is not registered, make sure build project with extra API") if not self.__chandle__() == 0: # need to explicit convert to str in case String # returned and triggered another infinite recursion in get state - return {"handle": str(__object_save_json__(self))} + return {"handle": str(_OBJECT_TO_JSON_GRAPH_STR(self, None))} return {"handle": None} def __setstate__(self, state): # pylint: disable=assigning-non-slot, assignment-from-no-return + if _OBJECT_FROM_JSON_GRAPH_STR is None: + raise RuntimeError("ffi.FromJSONGraphString is not registered, make sure build project with extra API") handle = state["handle"] if handle is not None: - self.__init_handle_by_constructor__(__object_load_json__, handle) + self.__init_handle_by_constructor__(_OBJECT_FROM_JSON_GRAPH_STR, handle) else: self.chandle = NULL - def __getattr__(self, name): - if self.chandle == NULL: - raise AttributeError(f"{type(self)} has no attribute {name}") - try: - return __object_getattr__(self, name) - except AttributeError: - raise AttributeError(f"{type(self)} has no attribute {name}") - - def __dir__(self): - # exception safety handling for chandle=None - if self.chandle == NULL: - return [] - return __object_dir__(self) - def __repr__(self): # exception safety handling for chandle=None if self.chandle == NULL: @@ -147,9 +123,6 @@ cdef class Object: def __ne__(self, other): return not self.__eq__(other) - def __init_handle_by_load_json__(self, json_str): - raise NotImplementedError("JSON serialization depends on downstream init") - def __init_handle_by_constructor__(self, fconstructor, *args): """Initialize the handle by calling constructor function. @@ -269,6 +242,15 @@ def _object_type_key_to_index(str type_key): return tidx return None +cdef inline str _type_index_to_key(int32_t tindex): + """get the type key of object class""" + cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(tindex) + cdef const TVMFFIByteArray* type_key + if info == NULL: + return "" + type_key = &(info.type_key) + return py_str(PyBytes_FromStringAndSize(type_key.data, type_key.size)) + cdef inline object make_ret_object(TVMFFIAny result): global OBJECT_TYPE @@ -284,10 +266,14 @@ cdef inline object make_ret_object(TVMFFIAny result): (obj).chandle = result.v_obj return cls.__from_tvm_ffi_object__(cls, obj) obj = cls.__new__(cls) - else: - obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) - else: - obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + (obj).chandle = result.v_obj + return obj + + # object is not found in registered entry + # in this case we need to report an warning + type_key = _type_index_to_key(tindex) + warnings.warn(f"Returning type `{type_key}` which is not registered via register_object, fallback to Object") + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) (obj).chandle = result.v_obj return obj diff --git a/python/tvm/ffi/serialization.py b/python/tvm/ffi/serialization.py new file mode 100644 index 000000000000..25d9bcefb828 --- /dev/null +++ b/python/tvm/ffi/serialization.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Serialization related utilities to enable some object can be pickled""" + +from typing import Optional, Any +from . import _ffi_api + + +def to_json_graph_str(obj: Any, metadata: Optional[dict] = None): + """ + Dump an object to a JSON graph string. + + The JSON graph string is a string representation of of the object + graph includes the reference information of same objects, which can + be used for serialization and debugging. + + Parameters + ---------- + obj : Any + The object to save. + + metadata : Optional[dict], optional + Extra metadata to save into the json graph string. + + Returns + ------- + json_str : str + The JSON graph string. + """ + return _ffi_api.ToJSONGraphString(obj, metadata) + + +def from_json_graph_str(json_str: str): + """ + Load an object from a JSON graph string. + + The JSON graph string is a string representation of of the object + graph that also includes the reference information. + + Parameters + ---------- + json_str : str + The JSON graph string to load. + + Returns + ------- + obj : Any + The loaded object. + """ + return _ffi_api.FromJSONGraphString(json_str) + + +__all__ = ["from_json_graph_str", "to_json_graph_str"] diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index e7de1a9f909b..cab982f4e783 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -41,7 +41,7 @@ def get_int_tuple(self, key): ------- value: Tuple of int """ - return tuple(x if isinstance(x, int) else x.value for x in self.__getattr__(key)) + return tuple(x if isinstance(x, int) else x.value for x in getattr(self, key)) def get_int(self, key): """Get a python int value of a key @@ -54,7 +54,7 @@ def get_int(self, key): ------- value: int """ - return self.__getattr__(key) + return getattr(self, key) def get_str(self, key): """Get a python int value of a key @@ -67,10 +67,10 @@ def get_str(self, key): ------- value: int """ - return self.__getattr__(key) + return getattr(self, key) def __getitem__(self, item): - return self.__getattr__(item) + return getattr(self, item) @tvm.ffi.register_object("ir.DictAttrs") @@ -101,6 +101,12 @@ def get(self, key, default=None): def __contains__(self, k): return self._dict().__contains__(k) + def __getattr__(self, name): + try: + return self._dict().__getitem__(name) + except KeyError: + raise AttributeError(f"DictAttrs has no attribute {name}") + def items(self): """Get items from the map.""" return self._dict().items() diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 633c2c6790da..eca885e03acb 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -326,7 +326,7 @@ def __init__(self, name_hint: str = ""): @register_df_node -class DataflowVarPattern(DFPattern): +class DataflowVarPattern(VarPattern): """A pattern for DataflowVar. Parameters diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 9ddaf52e722c..ee9caf3a835b 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -1177,6 +1177,11 @@ def const( return Constant(value) +@tvm.ffi.register_object("relax.TEPlaceholderOp") +class TEPlaceholderOp(tvm.te.tensor.Operation): + """The placeholder op that represents a relax expression.""" + + def te_tensor( value: Expr, tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr], name: str = "rxplaceholder" ): diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py index 41eaa5de5008..fd80f1e31333 100644 --- a/python/tvm/relax/op/_op_gradient.py +++ b/python/tvm/relax/op/_op_gradient.py @@ -829,8 +829,8 @@ def cumsum_grad( The "reversed" cumsum along the same axis. Implemented by some tricks now. """ - axis = orig_call.attrs["axis"] - dtype = orig_call.attrs["dtype"] + axis = orig_call.attrs.axis + dtype = orig_call.attrs.dtype x_shape = _get_shape(orig_call.args[0]) if axis is not None: diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 864eb3fec709..bb134f114855 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -624,6 +624,7 @@ def index_put( Examples -------- .. code-block:: python + # inputs data = torch.zeros(3, 3) indices = (torch.tensor([0, 2]), torch.tensor([1, 1])) diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 3e0f87c48751..9c15cdd96613 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -202,3 +202,158 @@ class FlipAttrs(Attrs): @tvm.ffi.register_object("relax.attrs.PadAttrs") class PadAttrs(Attrs): """Attributes used in pad operator""" + + +@tvm.ffi.register_object("relax.attrs.MultinomialFromUniformAttrs") +class MultinomialFromUniformAttrs(Attrs): + """Attributes for multinomial_from_uniform operator""" + + +@tvm.ffi.register_object("relax.attrs.CallInplacePackedAttrs") +class CallInplacePackedAttrs(Attrs): + """Attributes used in call_inplace_packed operator""" + + +@tvm.ffi.register_object("relax.attrs.CallTIRInplaceAttrs") +class CallTIRInplaceAttrs(Attrs): + """Attributes used in call_tir_inplace operator""" + + +@tvm.ffi.register_object("relax.attrs.ToVDeviceAttrs") +class ToVDeviceAttrs(Attrs): + """Attributes used in to_vdevice operator""" + + +@tvm.ffi.register_object("relax.attrs.HintOnDeviceAttrs") +class HintOnDeviceAttrs(Attrs): + """Attributes used in hint_on_device operator""" + + +@tvm.ffi.register_object("relax.attrs.ScatterCollectiveAttrs") +class ScatterCollectiveAttrs(Attrs): + """Attributes used in scatter collective operators""" + + +@tvm.ffi.register_object("relax.attrs.AttentionAttrs") +class AttentionAttrs(Attrs): + """Attributes used in attention operator""" + + +@tvm.ffi.register_object("relax.attrs.Conv1DAttrs") +class Conv1DAttrs(Attrs): + """Attributes for nn.conv1d""" + + +@tvm.ffi.register_object("relax.attrs.Conv1DTransposeAttrs") +class Conv1DTransposeAttrs(Attrs): + """Attributes for nn.conv1d_transpose""" + + +@tvm.ffi.register_object("relax.attrs.Pool1DAttrs") +class Pool1DAttrs(Attrs): + """Attributes for nn.max_pool1d and nn.avg_pool1d""" + + +@tvm.ffi.register_object("relax.attrs.Pool3DAttrs") +class Pool3DAttrs(Attrs): + """Attributes for nn.max_pool3d and nn.avg_pool3d""" + + +@tvm.ffi.register_object("relax.attrs.AdaptivePool1DAttrs") +class AdaptivePool1DAttrs(Attrs): + """Attributes for 1d adaptive pool operator""" + + +@tvm.ffi.register_object("relax.attrs.AdaptivePool3DAttrs") +class AdaptivePool3DAttrs(Attrs): + """Attributes for 3d adaptive pool operator""" + + +@tvm.ffi.register_object("relax.attrs.LeakyReluAttrs") +class LeakyReluAttrs(Attrs): + """Attributes used in leaky_relu operator""" + + +@tvm.ffi.register_object("relax.attrs.SoftplusAttrs") +class SoftplusAttrs(Attrs): + """Attributes used in softplus operator""" + + +@tvm.ffi.register_object("relax.attrs.PReluAttrs") +class PReluAttrs(Attrs): + """Attributes used in prelu operator""" + + +@tvm.ffi.register_object("relax.attrs.PixelShuffleAttrs") +class PixelShuffleAttrs(Attrs): + """Attributes used in pixel_shuffle operator""" + + +@tvm.ffi.register_object("relax.attrs.GroupNormAttrs") +class GroupNormAttrs(Attrs): + """Attributes used in group_norm operator""" + + +@tvm.ffi.register_object("relax.attrs.RMSNormAttrs") +class RMSNormAttrs(Attrs): + """Attributes used in rms_norm operator""" + + +@tvm.ffi.register_object("relax.attrs.NLLLossAttrs") +class NLLLossAttrs(Attrs): + """Attributes used in nll_loss operator""" + + +@tvm.ffi.register_object("relax.attrs.AllReduceAttrs") +class AllReduceAttrs(Attrs): + """Attributes used in allreduce operator""" + + +@tvm.ffi.register_object("relax.attrs.AllGatherAttrs") +class AllGatherAttrs(Attrs): + """Attributes used in allgather operator""" + + +@tvm.ffi.register_object("relax.attrs.WrapParamAttrs") +class WrapParamAttrs(Attrs): + """Attributes used in wrap_param operator""" + + +@tvm.ffi.register_object("relax.attrs.QuantizeAttrs") +class QuantizeAttrs(Attrs): + """Attributes used in quantize/dequantize operators""" + + +@tvm.ffi.register_object("relax.attrs.GatherElementsAttrs") +class GatherElementsAttrs(Attrs): + """Attributes for gather_elements operator""" + + +@tvm.ffi.register_object("relax.attrs.GatherNDAttrs") +class GatherNDAttrs(Attrs): + """Attributes for gather_nd operator""" + + +@tvm.ffi.register_object("relax.attrs.MeshgridAttrs") +class MeshgridAttrs(Attrs): + """Attributes for meshgrid operator""" + + +@tvm.ffi.register_object("relax.attrs.ScatterElementsAttrs") +class ScatterElementsAttrs(Attrs): + """Attributes for scatter_elements operator""" + + +@tvm.ffi.register_object("relax.attrs.ScatterNDAttrs") +class ScatterNDAttrs(Attrs): + """Attributes for scatter_nd operator""" + + +@tvm.ffi.register_object("relax.attrs.SliceScatterAttrs") +class SliceScatterAttrs(Attrs): + """Attributes for slice_scatter operator""" + + +@tvm.ffi.register_object("relax.attrs.OneHotAttrs") +class OneHotAttrs(Attrs): + """Attributes for one_hot operator""" diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index aef9ded9cc0d..4a0edd449c24 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -28,14 +28,6 @@ def AsRepr(obj): return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" -def NodeListAttrNames(obj): - return lambda x: 0 - - -def NodeGetAttr(obj, name): - raise AttributeError() - - def SaveJSON(obj): raise RuntimeError("Do not support object serialization in runtime only mode") diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index 688682d197c5..b2fcddc40ad6 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -22,17 +22,6 @@ from . import _ffi_node_api -def __object_dir__(obj): - class_names = dir(obj.__class__) - fnames = _ffi_node_api.NodeListAttrNames(obj) - size = fnames(-1) - return sorted([fnames(i) for i in range(size)] + class_names) - - tvm.ffi.core._set_class_object(Object) # override the default repr function for tvm.ffi.core.Object tvm.ffi.core.__object_repr__ = _ffi_node_api.AsRepr -tvm.ffi.core.__object_save_json__ = _ffi_node_api.SaveJSON -tvm.ffi.core.__object_load_json__ = _ffi_node_api.LoadJSON -tvm.ffi.core.__object_getattr__ = _ffi_node_api.NodeGetAttr -tvm.ffi.core.__object_dir__ = __object_dir__ diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 02a67e916bc0..bf468b17ec18 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -26,25 +26,12 @@ from . import _ffi_api +@register_object("script.printer.Doc") class Doc(Object): """Base class of all Docs""" - @property - def source_paths(self) -> Sequence[ObjectPath]: - """ - The list of object paths of the source IR node. - - This is used to trace back to the IR node position where - this Doc is generated, in order to position the diagnostic - message. - """ - return self.__getattr__("source_paths") # pylint: disable=unnecessary-dunder-call - - @source_paths.setter - def source_paths(self, value): - return _ffi_api.DocSetSourcePaths(self, value) # type: ignore # pylint: disable=no-member - +@register_object("script.printer.ExprDoc") class ExprDoc(Doc): """Base class of all expression Docs""" @@ -114,26 +101,10 @@ def __iter__(self): raise RuntimeError(f"{self.__class__} cannot be used as iterable.") +@register_object("script.printer.StmtDoc") class StmtDoc(Doc): """Base class of statement doc""" - @property - def comment(self) -> Optional[str]: - """ - The comment of this doc. - - The actual position of the comment depends on the type of Doc - and also the DocPrinter implementation. It could be on the same - line as the statement, or the line above, or inside the statement - if it spans over multiple lines. - """ - # It has to call the dunder method to avoid infinite recursion - return self.__getattr__("comment") # pylint: disable=unnecessary-dunder-call - - @comment.setter - def comment(self, value): - return _ffi_api.StmtDocSetComment(self, value) # type: ignore # pylint: disable=no-member - @register_object("script.printer.StmtBlockDoc") class StmtBlockDoc(Doc): diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 489ec38ba506..73b995a45e61 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -84,26 +84,6 @@ def ndim(self): """Dimension of the tensor.""" return len(self.shape) - @property - def axis(self): - """Axis of the tensor.""" - return self.__getattr__("axis") - - @property - def op(self): - """The corressponding :py:class:`Operation`.""" - return self.__getattr__("op") - - @property - def value_index(self): - """The output value index the tensor corresponds to.""" - return self.__getattr__("value_index") - - @property - def shape(self): - """The output shape of the tensor.""" - return self.__getattr__("shape") - @property def name(self): op = self.op diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py index ea798242b462..620a66351d9c 100644 --- a/python/tvm/testing/__init__.py +++ b/python/tvm/testing/__init__.py @@ -43,3 +43,4 @@ ) from .runner import local_run, rpc_run from .utils import * +from .attrs import * diff --git a/python/tvm/testing/attrs.py b/python/tvm/testing/attrs.py new file mode 100644 index 000000000000..ea6f1b1af65c --- /dev/null +++ b/python/tvm/testing/attrs.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, import-outside-toplevel, unused-variable +"""Testing utilities for attrs""" +from ..ir import Attrs +from ..ffi import register_object + + +@register_object("attrs.TestAttrs") +class TestAttrs(Attrs): + """Attrs used for testing purposes""" + + +__all__ = ["TestAttrs"] diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 81ce63b7972f..93a182ca3bc2 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -23,6 +23,8 @@ from . import _ffi_api from . import function_pass as _fpass +from ... import ir as _ir +from ... import ffi as _ffi def Apply(ftransform): @@ -48,6 +50,11 @@ def _transform(func, mod, ctx): return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") # type: ignore +@_ffi.register_object("tir.transform.LoopPartitionConfig") +class LoopPartitionConfig(_ir.Attrs): + """Config for loop partition pass""" + + def LoopPartition(): """Inject virtual thread loops. @@ -87,6 +94,11 @@ def InjectVirtualThread(): return _ffi_api.InjectVirtualThread() # type: ignore +@_ffi.register_object("tir.transform.InjectDoubleBufferConfig") +class InjectDoubleBufferConfig(_ir.Attrs): + """Config for inject double buffer pass""" + + def InjectDoubleBuffer(): """Inject double buffer statements. @@ -149,6 +161,11 @@ def PointerValueTypeRewrite(): return _ffi_api.PointerValueTypeRewrite() # type: ignore +@_ffi.register_object("tir.transform.UnrollLoopConfig") +class UnrollLoopConfig(_ir.Attrs): + """Config for unroll loop pass""" + + def UnrollLoop(): """Unroll the constant loop marked by unroll. @@ -162,6 +179,11 @@ def UnrollLoop(): return _ffi_api.UnrollLoop() # type: ignore +@_ffi.register_object("tir.transform.ReduceBranchingThroughOvercomputeConfig") +class ReduceBranchingThroughOvercomputeConfig(_ir.Attrs): + """Config for reduce branching through overcompute pass""" + + def ReduceBranchingThroughOvercompute(): """Reduce branching by introducing overcompute @@ -173,6 +195,11 @@ def ReduceBranchingThroughOvercompute(): return _ffi_api.ReduceBranchingThroughOvercompute() # type: ignore +@_ffi.register_object("tir.transform.RemoveNoOpConfig") +class RemoveNoOpConfig(_ir.Attrs): + """Config for remove no op pass""" + + def RemoveNoOp(): """Remove No Op from the Stmt. @@ -277,6 +304,11 @@ def RewriteUnsafeSelect(): return _ffi_api.RewriteUnsafeSelect() # type: ignore +@_ffi.register_object("tir.transform.SimplifyConfig") +class SimplifyConfig(_ir.Attrs): + """Config for simplify pass""" + + def Simplify(): """Run arithmetic simplifications on the statements and expressions. @@ -607,6 +639,11 @@ def VerifyVTCMLimit(limit=None): return _ffi_api.VerifyVTCMLimit(limit) # type: ignore +@_ffi.register_object("tir.transform.HoistIfThenElseConfig") +class HoistIfThenElseConfig(_ir.Attrs): + """Config for hoist if then else pass""" + + # pylint: disable=no-else-return,inconsistent-return-statements def HoistIfThenElse(variant: Optional[str] = None): """Hoist loop-invariant IfThenElse nodes to outside the eligible loops. @@ -686,6 +723,11 @@ class HoistedLetBindings(enum.Flag): """ Enable all hoisting of let bindings """ +@_ffi.register_object("tir.transform.HoistExpressionConfig") +class HoistExpressionConfig(_ir.Attrs): + """Config for hoist expression pass""" + + def HoistExpression(): """Generalized verison of HoistIfThenElse. diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 6db751a80f87..e666b434f8f5 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -33,75 +33,6 @@ using ffi::Any; using ffi::Function; using ffi::PackedArgs; -// Expose to FFI APIs. -void NodeGetAttr(ffi::PackedArgs args, ffi::Any* ret) { - Object* self = const_cast(args[0].cast()); - String field_name = args[1].cast(); - - bool success; - if (field_name == "type_key") { - *ret = self->GetTypeKey(); - success = true; - } else if (!self->IsInstance()) { - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index()); - success = false; - // use new reflection mechanism - if (type_info->metadata != nullptr) { - ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - if (field_name.compare(field_info->name) == 0) { - ffi::reflection::FieldGetter field_getter(field_info); - *ret = field_getter(self); - success = true; - } - }); - } - } else { - // specially handle dict attr - DictAttrsNode* dnode = static_cast(self); - auto it = dnode->dict.find(field_name); - if (it != dnode->dict.end()) { - success = true; - *ret = (*it).second; - } else { - success = false; - } - } - if (!success) { - TVM_FFI_THROW(AttributeError) << self->GetTypeKey() << " object has no attribute `" - << field_name << "`"; - } -} - -void NodeListAttrNames(ffi::PackedArgs args, ffi::Any* ret) { - Object* self = const_cast(args[0].cast()); - - std::vector names; - if (!self->IsInstance()) { - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index()); - if (type_info->metadata != nullptr) { - // use new reflection mechanism - ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - names.push_back(std::string(field_info->name.data, field_info->name.size)); - }); - } - } else { - // specially handle dict attr - DictAttrsNode* dnode = static_cast(self); - for (const auto& kv : dnode->dict) { - names.push_back(kv.first); - } - } - - *ret = ffi::Function::FromPacked([names](ffi::PackedArgs args, ffi::Any* rv) { - int64_t i = args[0].cast(); - if (i == -1) { - *rv = static_cast(names.size()); - } else { - *rv = names[i]; - } - }); -} - // API function to make node. // args format: // key1, value1, ..., key_n, value_n @@ -123,10 +54,7 @@ void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def_packed("node.NodeGetAttr", NodeGetAttr) - .def_packed("node.NodeListAttrNames", NodeListAttrNames) - .def_packed("node.MakeNode", MakeNode); + refl::GlobalDef().def_packed("node.MakeNode", MakeNode); }); } // namespace tvm diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h index bc4b90a37333..aa7cb9db538e 100644 --- a/src/relax/ir/emit_te.h +++ b/src/relax/ir/emit_te.h @@ -52,7 +52,7 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { .def_ro("dtype", &RXPlaceholderOpNode::dtype); } - static constexpr const char* _type_key = "RXPlaceholderOp"; + static constexpr const char* _type_key = "relax.TEPlaceholderOp"; TVM_DECLARE_FINAL_OBJECT_INFO(RXPlaceholderOpNode, te::PlaceholderOpNode); }; diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index d89114c68abd..1548ea1da625 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -82,7 +82,7 @@ struct HoistExpressionConfigNode : public AttrsNodeReflAdapter(flag) & hoisted_let_bindings; } - static constexpr const char* _type_key = "tir.transforms.HoistExpressionConfig"; + static constexpr const char* _type_key = "tir.transform.HoistExpressionConfig"; TVM_DECLARE_FINAL_OBJECT_INFO(HoistExpressionConfigNode, Object); }; @@ -112,7 +112,7 @@ struct HoistIfThenElseConfigNode : public AttrsNodeReflAdapter R.Tensor((5, "b * 2"), dtype="float32"): b = T.int64() lv: R.Shape([5, b * 2]) = R.shape([5, b * 2]) diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index 6711ccf92f3f..e696cbcf086c 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -413,7 +413,6 @@ def check(client, is_local): get_elem = client.get_function("testing.GetShapeElem") get_size = client.get_function("testing.GetShapeSize") shape = make_shape(2, 3) - assert shape.type_key == "runtime.RPCObjectRef" assert get_elem(shape, 0) == 2 assert get_elem(shape, 1) == 3 assert get_size(shape) == 2