diff --git a/include/tvm/ir/global_var_supply.h b/include/tvm/ir/global_var_supply.h deleted file mode 100644 index 2241385167e2..000000000000 --- a/include/tvm/ir/global_var_supply.h +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ir/global_var_supply.h - * \brief GlobalVarSupply that can be used to generate unique \class GlobalVar. - */ -#ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_ -#define TVM_IR_GLOBAL_VAR_SUPPLY_H_ - -#include -#include -#include -#include - -#include -#include - -namespace tvm { - -/*! - * \brief GlobalVarSupply can be used to generate unique GlobalVars. - */ -class GlobalVarSupplyNode : public ffi::Object { - public: - /*! - * \brief Empty constructor. Will use an empty NameSupply. - */ - GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply()) {} - - /*! - * \brief Constructor. - * \param name_supply The NameSupply to use for generating the names of fresh GlobalVars. - * \param name_to_var_map An optional map. - */ - explicit GlobalVarSupplyNode(NameSupply name_supply, - std::unordered_map name_to_var_map = {}); - - /*! - * \brief Generates a unique GlobalVar from this supply. - * \param name The name from which the name of the GlobalVar is derived. - * \param add_prefix If set to true, then the prefix of the contained NameSupply will be prepended - * to the name. \return A unique GlobalVar. - */ - GlobalVar FreshGlobal(ffi::String name, bool add_prefix = true); - - /*! - * \brief Looks up for a GlobalVar with the given name in this supply. - * If no entry is found, creates one, places it in the cache and returns it. - * \param name The name of the GlobalVar to search for. - * \param add_prefix If set to true, the prefix of the contained NameSupply will be prepended to - * the name before performing the search. \return A cached GlobalVar. - */ - GlobalVar UniqueGlobalFor(const ffi::String& name, bool add_prefix = true); - - /*! - * \brief Reserves an existing GlobalVar with this supply. - * \param var The GlobalVar to be registered. - * \param allow_conflict Allow conflict with other GlobalVars that have the same name. - */ - void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false); - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); - } - - /*! \brief The NameSupply used to generate unique name hints to GlobalVars. */ - NameSupply name_supply_; - - static constexpr const bool _type_mutable = true; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.GlobalVarSupply", GlobalVarSupplyNode, ffi::Object); - - private: - std::unordered_map name_to_var_map_; -}; - -/*! - * \brief Managed reference class to GlobalVarSupplyNode. - * \sa GlobalVarSupplyNode - */ -class GlobalVarSupply : public ffi::ObjectRef { - public: - /*! - * \brief Constructor. - * \param name_supply The NameSupply to be used when generating new GlobalVars. - * \param name_to_var_map An optional map. - */ - TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply = NameSupply(), - std::unordered_map name_to_var_map = {}); - - /*! - * \brief Constructs a supply from an array of IRModules. GlobalVars generated by this supply are - * guaranteed not to conflict with any GlobalVars that belong to the modules. \param modules Array - * of IRModules. - */ - TVM_DLL explicit GlobalVarSupply(const ffi::Array& modules); - - /*! - * \brief Constructs a GlobalVarSupply from an IRModule. GlobalVars generated by this supply are - * guaranteed not to conflict with GlobalVars that belong to the modules. \param module The - * IRModule. - */ - TVM_DLL explicit GlobalVarSupply(const IRModule module); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(GlobalVarSupply, ffi::ObjectRef, - GlobalVarSupplyNode); -}; - -} // namespace tvm - -#endif // TVM_IR_GLOBAL_VAR_SUPPLY_H_ diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h deleted file mode 100644 index 54bac2afc3b5..000000000000 --- a/include/tvm/ir/name_supply.h +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ir/name_supply.h - * \brief NameSupply that can be used to generate unique variable names. - */ -#ifndef TVM_IR_NAME_SUPPLY_H_ -#define TVM_IR_NAME_SUPPLY_H_ - -#include -#include - -#include -#include -#include -#include -#include - -namespace tvm { - -/*! - * \brief NameSupply can be used to generate unique names. - */ -class NameSupplyNode : public ffi::Object { - public: - /*! - * \brief Empty constructor. Needed by the TVM_REGISTER_NODE_TYPE macro. - */ - NameSupplyNode() = default; - - /*! - * \brief Constructor. - * \param prefix The prefix to be used with this NameSupply. - * \param name_map The map used to guarantee uniqueness. - */ - NameSupplyNode(const ffi::String& prefix, std::unordered_map name_map) - : prefix_(prefix), name_map(std::move(name_map)) {} - - /*! - * \brief Generates a unique name from this NameSupply. - * \param name The name from which the generated name is derived. - * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the - * name. - * \param add_underscore If set to true, add '_' between prefix and a digit. - * \return A unique name. - */ - ffi::String FreshName(const ffi::String& name, bool add_prefix = true, - bool add_underscore = true); - - /*! - * \brief Reserves an existing name with this NameSupply. - * \param name The name to be reserved. - * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the - * name before reserving it. \return The name that was reserved with the NameSupply. It can be - * different if a prefix is added. - */ - ffi::String ReserveName(const ffi::String& name, bool add_prefix = true); - - /*! - * \brief Checks if this NameSupply already generated a name. - * \param name The name to check. - * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the - * name before checking for it. \return True if the name has already been generated. False - * otherwise. - */ - bool ContainsName(const ffi::String& name, bool add_prefix = true); - - // Prefix for all GlobalVar names. It can be empty. - std::string prefix_; - - static constexpr const bool _type_mutable = true; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); - } - - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.NameSupply", NameSupplyNode, ffi::Object); - - private: - /*! \brief Helper function to add the NameSupply prefix to the name. */ - ffi::String add_prefix_to_name(const ffi::String& name); - - /*! - * \brief Function that will generate a unique name. - * \param name The name to be used as a base. - * \param add_underscore If set to true, add '_' between prefix and a digit. - * \return A unique name. - */ - std::string GetUniqueName(std::string name, bool add_underscore = true); - - /*! \brief A map that is used to generate unique names. */ - std::unordered_map name_map; -}; - -/*! - * \brief Managed reference class to NameSupplyNode. - * \sa NameSupplyNode - */ -class NameSupply : public ffi::ObjectRef { - public: - /*! - * \brief Constructor. - * \param prefix The prefix to be used with this NameSupply. - * \param name_map An optional map. - */ - TVM_DLL explicit NameSupply(const ffi::String& prefix = "", - std::unordered_map name_map = {}); - - /*! - * \brief Construct NameSupply with a name map created from the given iterator range and - * the functor. - * - * The functor should return the name of the dereferenced object. - */ - template - TVM_DLL explicit NameSupply(Iter begin, Iter end, Lambda f) - : NameSupply("", GetNameMap(begin, end, f)) {} - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(NameSupply, ffi::ObjectRef, NameSupplyNode); - - private: - template - static std::unordered_map GetNameMap(Iter begin, Iter end, Lambda f) { - // static_assert is more reader-friendly than SFINAE when template specialization is not needed. - static_assert(std::is_convertible::value, - "Lambda f must has a signature of [?](*it) -> string {}"); - std::unordered_map name_map; - for (auto it = begin; it != end; ++it) { - const std::string& name = f(*it); - const size_t idx_last_first_num = std::distance( - std::find_if(name.rbegin(), name.rend(), [](char c) { return !std::isdigit(c); }), - name.rend()); - // name = {O = others}{D = consecutive digits} - // let O -> prefix; - std::string prefix = name.substr(0, idx_last_first_num); - TVM_FFI_ICHECK(prefix.size() > 0 && std::isalpha(prefix[0])) - << "Invalid variable name: " << name; - if (0 == name_map.count(prefix)) name_map[prefix] = 0; - if (idx_last_first_num < name.size()) { // has some digits. - // let D's nearest natural number -> idx; - // note: stoul("000123") = 123; - name_map[prefix] = std::max(name_map[prefix], std::stoi(name.substr(idx_last_first_num))); - } - } - return name_map; - } -}; - -} // namespace tvm - -#endif // TVM_IR_NAME_SUPPLY_H_ diff --git a/include/tvm/ir/unique_name_supply.h b/include/tvm/ir/unique_name_supply.h new file mode 100644 index 000000000000..0f79318d1dc8 --- /dev/null +++ b/include/tvm/ir/unique_name_supply.h @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/unique_name_supply.h + * \brief UniqueNameSupply that can be used to generate unique variable names. + */ +#ifndef TVM_IR_UNIQUE_NAME_SUPPLY_H_ +#define TVM_IR_UNIQUE_NAME_SUPPLY_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { + +/*! + * \brief UniqueNameSupply can be used to generate unique names. + */ +class UniqueNameSupplyNode : public ffi::Object { + public: + /*! + * \brief Empty constructor. Needed by the TVM_REGISTER_NODE_TYPE macro. + */ + UniqueNameSupplyNode() = default; + + /*! + * \brief Constructor. + * \param prefix The prefix to be used with this UniqueNameSupply. + * \param name_map The map used to guarantee uniqueness. + */ + UniqueNameSupplyNode(const ffi::String& prefix, ffi::Map name_map) + : prefix_(prefix), name_map(std::move(name_map)) {} + + /*! + * \brief Generates a unique name from this UniqueNameSupply. + * \param name The name from which the generated name is derived. + * \param add_prefix If set to true, then the prefix of this UniqueNameSupply will be prepended to + * the name. + * \param add_underscore If set to true, add '_' between prefix and a digit. + * \return A unique name. + */ + ffi::String FreshName(const ffi::String& name, bool add_prefix = true, + bool add_underscore = true); + + /*! + * \brief Reserves an existing name with this UniqueNameSupply. + * \param name The name to be reserved. + * \param add_prefix If set to true, then the prefix of this UniqueNameSupply will be prepended to + * the name before reserving it. \return The name that was reserved with the UniqueNameSupply. It + * can be different if a prefix is added. + */ + ffi::String ReserveName(const ffi::String& name, bool add_prefix = true); + + /*! + * \brief Checks if this UniqueNameSupply already generated a name. + * \param name The name to check. + * \param add_prefix If set to true, then the prefix of this UniqueNameSupply will be prepended to + * the name before checking for it. \return True if the name has already been generated. False + * otherwise. + */ + bool ContainsName(const ffi::String& name, bool add_prefix = true); + + // Prefix for all GlobalVar names. It can be empty. + std::string prefix_; + + static constexpr const bool _type_mutable = true; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.UniqueNameSupply", UniqueNameSupplyNode, ffi::Object); + + private: + /*! \brief Helper function to add the UniqueNameSupply prefix to the name. */ + ffi::String AddPrefixToName(const ffi::String& name); + + /*! + * \brief Function that will generate a unique name. + * \param name The name to be used as a base. + * \param add_underscore If set to true, add '_' between prefix and a digit. + * \return A unique name. + */ + std::string GetUniqueName(std::string name, bool add_underscore = true); + + /*! \brief A map that is used to generate unique names. */ + ffi::Map name_map; +}; + +/*! + * \brief Managed reference class to UniqueNameSupplyNode. + * \sa UniqueNameSupplyNode + */ +class UniqueNameSupply : public ffi::ObjectRef { + public: + /*! + * \brief Constructor. + * \param prefix The prefix to be used with this UniqueNameSupply. + * \param name_map An optional map. + */ + TVM_DLL explicit UniqueNameSupply(const ffi::String& prefix = "", + ffi::Map name_map = {}); + + /*! + * \brief Construct UniqueNameSupply by reserving names from the given iterator range. + * + * The functor should return the name of the dereferenced object. + */ + template + TVM_DLL UniqueNameSupply(Iter begin, Iter end, Lambda f) : UniqueNameSupply("") { + for (auto it = begin; it != end; ++it) { + this->operator->()->ReserveName(f(*it), false); + } + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UniqueNameSupply, ffi::ObjectRef, + UniqueNameSupplyNode); +}; + +} // namespace tvm + +#endif // TVM_IR_UNIQUE_NAME_SUPPLY_H_ diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h index 69092726b474..740e8ed01fda 100644 --- a/include/tvm/relax/binding_rewrite.h +++ b/include/tvm/relax/binding_rewrite.h @@ -25,7 +25,7 @@ #ifndef TVM_RELAX_BINDING_REWRITE_H_ #include -#include +#include #include #include @@ -87,7 +87,7 @@ class DataflowBlockRewriteNode : public ffi::Object { ffi::Array fn_outputs_; //!< Variables required by function outputs. private: - NameSupply name_supply_; //!< Name supply for tracking and generating unique names. + UniqueNameSupply name_supply_; //!< Unique name supply for tracking and generating unique names. }; /*! diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 68d6fc7bfa2c..8413686dc9df 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -25,7 +25,7 @@ #define TVM_RELAX_BLOCK_BUILDER_H_ #include -#include +#include #include #include #include @@ -68,11 +68,11 @@ class BlockBuilderNode : public ffi::Object { // Global Context management //------------------------------- /*! - * \brief Get the name supply for generating unique names. + * \brief Get the unique name supply for generating unique names. * - * \return The name supply. + * \return The unique name supply. */ - virtual NameSupply name_supply() = 0; + virtual UniqueNameSupply name_supply() = 0; /*! * \brief Get the context IRModule in this builder. diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py index 183e20f25789..07b91e1a86e9 100644 --- a/python/tvm/ir/supply.py +++ b/python/tvm/ir/supply.py @@ -14,18 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Suppliers that are used to guarantee uniqueness of names and GlobalVars.""" +"""Suppliers that are used to guarantee uniqueness of names.""" import tvm_ffi -from tvm import IRModule, Object +from tvm import Object from . import _ffi_api -@tvm_ffi.register_object("ir.NameSupply") -class NameSupply(Object): - """NameSupply that can be used to generate unique names. +@tvm_ffi.register_object("ir.UniqueNameSupply") +class UniqueNameSupply(Object): + """UniqueNameSupply that can be used to generate unique names. Parameters ---------- @@ -33,10 +33,10 @@ class NameSupply(Object): """ def __init__(self, prefix=""): - self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix) + self.__init_handle_by_constructor__(_ffi_api.UniqueNameSupply, prefix) def fresh_name(self, name, add_prefix=True, add_underscore=True): - """Generates a unique name from this NameSupply. + """Generates a unique name from this UniqueNameSupply. Parameters ---------- @@ -44,15 +44,15 @@ def fresh_name(self, name, add_prefix=True, add_underscore=True): The name from which the generated name is derived. add_prefix: bool - If set to true, then the prefix of this NameSupply will be prepended to the name. + If set to true, then the prefix of this UniqueNameSupply will be prepended to the name. add_underscore: bool If set to True, adds '_' between prefix and digit. """ - return _ffi_api.NameSupply_FreshName(self, name, add_prefix, add_underscore) + return _ffi_api.UniqueNameSupply_FreshName(self, name, add_prefix, add_underscore) def reserve_name(self, name, add_prefix=True): - """Reserves an existing name with this NameSupply. + """Reserves an existing name with this UniqueNameSupply. Parameters ---------- @@ -60,13 +60,13 @@ def reserve_name(self, name, add_prefix=True): The name to be reserved. add_prefix: bool - If set to true, then the prefix of this NameSupply will be prepended to the name + If set to true, then the prefix of this UniqueNameSupply will be prepended to the name before reserving it. """ - return _ffi_api.NameSupply_ReserveName(self, name, add_prefix) + return _ffi_api.UniqueNameSupply_ReserveName(self, name, add_prefix) def contains_name(self, name, add_prefix=True): - """Checks if this NameSupply already generated a name. + """Checks if this UniqueNameSupply already generated a name. Parameters ---------- @@ -74,74 +74,7 @@ def contains_name(self, name, add_prefix=True): The name to check. add_prefix: bool - If set to true, then the prefix of this NameSupply will be prepended to the name + If set to true, then the prefix of this UniqueNameSupply will be prepended to the name before checking for it. """ - return _ffi_api.NameSupply_ContainsName(self, name, add_prefix) - - -@tvm_ffi.register_object("ir.GlobalVarSupply") -class GlobalVarSupply(Object): - """GlobalVarSupply that holds a mapping between names and GlobalVars. - - GlobalVarSupply can be used to generate new GlobalVars with a unique name. - It also can be used to retrieve previously generated GlobalVars based on a name. - - Parameters - ---------- - value: Union[List[IRModule], IRModule, NameSupply] - The IRModules used to build this GlobalVarSupply or a NameSupply. - """ - - def __init__(self, value=None): - if value is None: - name_supply = NameSupply("") - self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, name_supply) - elif isinstance(value, NameSupply): - self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, value) - elif isinstance(value, list | tvm_ffi.Array): - self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModules, value) - elif isinstance(value, IRModule): - self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModule, value) - - def fresh_global(self, name, add_prefix=True): - """Generates a unique GlobalVar from this supply. - - Parameters - ---------- - name: String - The name from which the name of the GlobalVar is derived. - - add_prefix: bool - If set to true, then the prefix of the contained NameSupply will be prepended - to the name. - """ - return _ffi_api.GlobalVarSupply_FreshGlobal(self, name, add_prefix) - - def unique_global_for(self, name, add_prefix=True): - """Looks up for a GlobalVar with the given name in this supply. If no entry is found - , creates one, places it in the cache and returns it. - - Parameters - ---------- - name: String - The name of the GlobalVar to search for. - - add_prefix: bool - If set to true, the prefix of the contained NameSupply will be prepended to the - name before performing the search. - """ - return _ffi_api.GlobalVarSupply_UniqueGlobalFor(self, name, add_prefix) - - def reserve_global(self, global_var, allow_conflict=False): - """Reserves an existing GlobalVar with this supply. - - Parameters - ---------- - global_var: GlobalVar - The GlobalVar to be registered. - - allow_conflict: bool - Allow conflict with other GlobalVars that have the same name - """ - return _ffi_api.GlobalVarSupply_ReserveGlobalVar(self, global_var, allow_conflict) + return _ffi_api.UniqueNameSupply_ContainsName(self, name, add_prefix) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index d64020bfc772..1da471e03c75 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -59,7 +59,7 @@ import tvm from tvm import relax, tirx, topi from tvm.ir import IRModule -from tvm.ir.supply import NameSupply +from tvm.ir.supply import UniqueNameSupply from tvm.runtime import DataType, DataTypeCode from tvm.tirx.generic import cast from tvm.topi.utils import get_const_tuple @@ -5337,7 +5337,7 @@ def __init__( self._input_names: list[str] = [] self._dtype = dtype_dict self.opset: int = None - self._name_supply = NameSupply() + self._name_supply = UniqueNameSupply() self._keep_params_in_input = keep_params_in_input self._sanitize: bool = sanitize self.bb: relax.BlockBuilder = relax.BlockBuilder() # pylint: disable=invalid-name diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index ee5f3e1dd43c..c51cb05dc4e8 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -21,7 +21,7 @@ from tvm_ffi._dtype import dtype as DataType, DataTypeCode # Import _ffi_node_api for its side effect of installing AsRepr as -# tvm_ffi.core.__object_repr__ so TVM IR objects use the rich C++ ReprPrinter. +# tvm_ffi.core.__object_repr__. from . import _ffi_node_api # class exposures diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index 18af61ec7563..1c87b989b69f 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -40,8 +40,5 @@ def LoadJSON(json_str): # Exports functions registered in node namespace. tvm_ffi.init_ffi_api("node", __name__) -# Override the default repr function for tvm_ffi.core.Object so TVM IR -# objects use the rich C++ ReprPrinter (registered above via init_ffi_api), -# falling back to the runtime-only AsRepr defined in this file when libtvm.so -# is not available. +# Override the default repr function for tvm_ffi.core.Object. tvm_ffi.core.__object_repr__ = AsRepr diff --git a/src/ir/access_path_repr.cc b/src/ir/access_path_repr.cc deleted file mode 100644 index 8225452c5446..000000000000 --- a/src/ir/access_path_repr.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ir/access_path_repr.cc - * \brief FFI registration for ffi-repr-based printing. - * - * This file: - * - Registers node.AsRepr (for backward Python compatibility) via ffi::ReprPrint. - * - * Note: __ffi_repr__ hooks for ffi::reflection::AccessPath and AccessStep are - * registered by tvm-ffi itself (src/ffi/extra/reflection_extra.cc, landed in - * apache/tvm-ffi#598). The duplicate registrations that previously lived here - * were removed when bumping tvm-ffi to 59da4c0 to avoid a double-registration - * abort at library load time. - * - * Note: tvm::Dump() has been removed (zero in-tree callers). Use - * tvm::ffi::ReprPrint(any) directly from gdb instead. - */ -#include -#include -#include - -namespace tvm { - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - // node.AsRepr: backward-compatible Python entry point. - // Python's tvm.runtime._ffi_node_api sets __object_repr__ = AsRepr via init_ffi_api. - refl::GlobalDef().def("node.AsRepr", - [](ffi::Any obj) -> ffi::String { return ffi::ReprPrint(obj); }); -} -} // namespace tvm diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc deleted file mode 100644 index 700c3ef84038..000000000000 --- a/src/ir/global_var_supply.cc +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file global_var_supply.cc - * \brief GlobalVarSupply that can be used to generate unique GlobalVars. - */ -#include "tvm/ir/global_var_supply.h" - -#include -#include - -#include - -#include "tvm/ir/expr.h" - -namespace tvm { - -TVM_FFI_STATIC_INIT_BLOCK() { GlobalVarSupplyNode::RegisterReflection(); } - -GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply, - std::unordered_map name_to_var_map) { - auto n = ffi::make_object(name_supply, name_to_var_map); - data_ = std::move(n); -} - -std::string GetModuleName(const IRModule& module) { - return module->GetAttr(tvm::attr::kModuleName).value_or("tvmgen_default"); -} - -GlobalVarSupply::GlobalVarSupply(const ffi::Array& modules) : GlobalVarSupply() { - if (!modules.empty()) { - IRModule first_mod = modules.front(); - this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod); - } - for (auto& mod : modules) { - for (auto kv : mod->functions) { - this->operator->()->ReserveGlobalVar(kv.first); - } - } -} - -GlobalVarSupply::GlobalVarSupply(const IRModule module) - : GlobalVarSupply(ffi::Array{module}) {} - -void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) { - name_supply_->ReserveName(var->name_hint, false); - if (!allow_conflict) { - TVM_FFI_ICHECK(name_to_var_map_.count(var->name_hint) == 0) - << "GlobalVar " << var << " conflicts by name in this supply."; - } - name_to_var_map_[var->name_hint] = var; -} - -GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply, - std::unordered_map name_to_var_map) - : name_supply_(std::move(name_supply)), name_to_var_map_(std::move(name_to_var_map)) {} - -GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const ffi::String& name, bool add_prefix) { - ffi::String final_name = name_supply_->ReserveName(name, add_prefix); - - auto it = name_to_var_map_.find(final_name); - if (it != name_to_var_map_.end()) { - return it->second; - } else { - GlobalVar var = GlobalVar(final_name); - name_to_var_map_.emplace(final_name, var); - return var; - } -} - -GlobalVar GlobalVarSupplyNode::FreshGlobal(ffi::String name, bool add_prefix) { - ffi::String final_name = name_supply_->FreshName(name, add_prefix); - TVM_FFI_ICHECK(name_to_var_map_.find(final_name) == name_to_var_map_.end()) - << "GlobalVar already exists for name " << final_name; - GlobalVar var = GlobalVar(final_name); - name_to_var_map_.emplace(final_name, var); - return var; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("ir.GlobalVarSupply_NameSupply", - [](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); }) - .def("ir.GlobalVarSupply_IRModule", - [](IRModule mod) { return GlobalVarSupply(std::move(mod)); }) - .def("ir.GlobalVarSupply_IRModules", - [](const ffi::Array& mods) { return GlobalVarSupply(mods); }) - .def_method("ir.GlobalVarSupply_FreshGlobal", &GlobalVarSupplyNode::FreshGlobal) - .def_method("ir.GlobalVarSupply_UniqueGlobalFor", &GlobalVarSupplyNode::UniqueGlobalFor) - .def_method("ir.GlobalVarSupply_ReserveGlobalVar", &GlobalVarSupplyNode::ReserveGlobalVar); -} - -} // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index a09780d94dc5..156ca17c1255 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -28,9 +28,9 @@ #include #include #include -#include #include #include +#include #include #include @@ -219,13 +219,17 @@ IRModule IRModule::FromExpr(const RelaxExpr& expr, } } + UniqueNameSupply global_names(mod->functions.begin(), mod->functions.end(), + [](const auto& kv) { return kv.first->name_hint; }); GlobalVar main_gv; - auto global_var_supply = GlobalVarSupply(mod); if (gv_name.empty()) { // Bind function to 'main' (though rename if would clash with existing 'main'). - main_gv = global_var_supply->FreshGlobal("main", false); + main_gv = GlobalVar(global_names->FreshName("main", false)); + } else if (mod->ContainGlobalVar(gv_name)) { + main_gv = mod->GetGlobalVar(gv_name); } else { - main_gv = global_var_supply->UniqueGlobalFor(gv_name, false); + global_names->ReserveName(gv_name, false); + main_gv = GlobalVar(gv_name); } mod->Add(main_gv, func); return mod; diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc deleted file mode 100644 index 2f7bf501e55a..000000000000 --- a/src/ir/name_supply.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file name_supply.cc - * \brief NameSupply that can be used to generate unique variable names. - */ -#include "tvm/ir/name_supply.h" - -#include -#include - -#include - -namespace tvm { - -NameSupply::NameSupply(const ffi::String& prefix, std::unordered_map name_map) { - auto n = ffi::make_object(prefix, std::move(name_map)); - data_ = std::move(n); -} - -ffi::String NameSupplyNode::ReserveName(const ffi::String& name, bool add_prefix) { - ffi::String final_name = name; - if (add_prefix) { - final_name = add_prefix_to_name(name); - } - name_map[final_name] = 0; - return final_name; -} - -ffi::String NameSupplyNode::FreshName(const ffi::String& name, bool add_prefix, - bool add_underscore) { - ffi::String unique_name = name; - if (unique_name.empty()) { - // Special case for empty name, set to "v". - unique_name = "v"; - } - if (add_prefix) { - unique_name = add_prefix_to_name(unique_name); - } - unique_name = GetUniqueName(unique_name, add_underscore); - return unique_name; -} - -bool NameSupplyNode::ContainsName(const ffi::String& name, bool add_prefix) { - ffi::String unique_name = name; - if (add_prefix) { - unique_name = add_prefix_to_name(name); - } - - return name_map.count(unique_name); -} - -ffi::String NameSupplyNode::add_prefix_to_name(const ffi::String& name) { - if (prefix_.empty()) { - return name; - } - - std::ostringstream ss; - ss << prefix_ << "_" << name; - return ss.str(); -} - -std::string NameSupplyNode::GetUniqueName(std::string name, bool add_underscore) { - for (size_t i = 0; i < name.size(); ++i) { - if (name[i] == '.') name[i] = '_'; - } - auto it = name_map.find(name); - if (it != name_map.end()) { - auto new_name = name; - while (!name_map.insert({new_name, 0}).second) { - std::ostringstream os; - os << name << (add_underscore ? "_" : "") << (++it->second); - new_name = os.str(); - } - return new_name; - } - name_map[name] = 0; - return name; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - NameSupplyNode::RegisterReflection(); - refl::GlobalDef() - .def("ir.NameSupply", [](ffi::String prefix) { return NameSupply(prefix); }) - .def_method("ir.NameSupply_FreshName", &NameSupplyNode::FreshName) - .def_method("ir.NameSupply_ReserveName", &NameSupplyNode::ReserveName) - .def_method("ir.NameSupply_ContainsName", &NameSupplyNode::ContainsName); -} - -} // namespace tvm diff --git a/src/ir/unique_name_supply.cc b/src/ir/unique_name_supply.cc new file mode 100644 index 000000000000..481edcac89ce --- /dev/null +++ b/src/ir/unique_name_supply.cc @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file unique_name_supply.cc + * \brief UniqueNameSupply that can be used to generate unique variable names. + */ +#include "tvm/ir/unique_name_supply.h" + +#include +#include + +#include +#include + +namespace tvm { + +UniqueNameSupply::UniqueNameSupply(const ffi::String& prefix, + ffi::Map name_map) { + if (!name_map.defined()) { + name_map = ffi::Map(); + } + auto n = ffi::make_object(prefix, std::move(name_map)); + data_ = std::move(n); +} + +ffi::String UniqueNameSupplyNode::ReserveName(const ffi::String& name, bool add_prefix) { + ffi::String final_name = name; + if (add_prefix) { + final_name = AddPrefixToName(name); + } + name_map.Set(final_name, 0); + return final_name; +} + +ffi::String UniqueNameSupplyNode::FreshName(const ffi::String& name, bool add_prefix, + bool add_underscore) { + ffi::String unique_name = name; + if (unique_name.empty()) { + unique_name = "v"; + } + if (add_prefix) { + unique_name = AddPrefixToName(unique_name); + } + return GetUniqueName(unique_name, add_underscore); +} + +bool UniqueNameSupplyNode::ContainsName(const ffi::String& name, bool add_prefix) { + ffi::String unique_name = name; + if (add_prefix) { + unique_name = AddPrefixToName(name); + } + return name_map.count(unique_name); +} + +ffi::String UniqueNameSupplyNode::AddPrefixToName(const ffi::String& name) { + if (prefix_.empty()) { + return name; + } + + std::ostringstream ss; + ss << prefix_ << "_" << name; + return ss.str(); +} + +std::string UniqueNameSupplyNode::GetUniqueName(std::string name, bool add_underscore) { + for (size_t i = 0; i < name.size(); ++i) { + if (name[i] == '.') name[i] = '_'; + } + ffi::String name_key = name; + auto it = name_map.find(name_key); + if (it != name_map.end()) { + auto new_name = name; + int64_t suffix = (*it).second; + while (name_map.count(ffi::String(new_name))) { + std::ostringstream os; + os << name << (add_underscore ? "_" : "") << (++suffix); + new_name = os.str(); + } + name_map.Set(name_key, suffix); + name_map.Set(ffi::String(new_name), 0); + return new_name; + } + name_map.Set(name_key, 0); + return name; +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + UniqueNameSupplyNode::RegisterReflection(); + refl::GlobalDef() + .def("ir.UniqueNameSupply", [](ffi::String prefix) { return UniqueNameSupply(prefix); }) + .def_method("ir.UniqueNameSupply_FreshName", &UniqueNameSupplyNode::FreshName) + .def_method("ir.UniqueNameSupply_ReserveName", &UniqueNameSupplyNode::ReserveName) + .def_method("ir.UniqueNameSupply_ContainsName", &UniqueNameSupplyNode::ContainsName); +} + +} // namespace tvm diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 91840f6936e5..6de72397dc52 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include #include @@ -333,8 +333,8 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, * name_hint. */ std::unordered_map var_name_map_; - /*! \brief A name supply to generate a unique name for each parameter. */ - NameSupply name_sup_; + /*! \brief A unique name supply to generate a unique name for each parameter. */ + UniqueNameSupply name_sup_; }; class CutlassModuleCodegen { diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index 9fad59f4e374..85fcfef1ea56 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -48,8 +48,8 @@ DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) auto p = FunctionUseDef(root_fn); n->to_users_ = std::move(p.first); n->fn_outputs_ = std::move(p.second); - n->name_supply_ = NameSupply(n->to_users_.begin(), n->to_users_.end(), - [](const auto& p) { return p.first->name_hint(); }); + n->name_supply_ = UniqueNameSupply(n->to_users_.begin(), n->to_users_.end(), + [](const auto& p) { return p.first->name_hint(); }); data_ = std::move(n); } diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 344b09024e59..f9360c6c4246 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -70,7 +70,7 @@ class BlockBuilderImpl : public BlockBuilderNode { //------------------------------- // Global Context management //------------------------------- - NameSupply name_supply() final { return name_supply_; } + UniqueNameSupply name_supply() final { return name_supply_; } IRModule GetContextIRModule() const final { return context_mod_; } @@ -346,8 +346,8 @@ class BlockBuilderImpl : public BlockBuilderNode { /*! \brief A binding table that maps var to value. */ std::unordered_map binding_table_; - /*! \brief A name supply to get unique names for IR construction. */ - NameSupply name_supply_; + /*! \brief A unique name supply to get unique names for IR construction. */ + UniqueNameSupply name_supply_; /*! \brief The IRModule being built by the BlockBuilder. */ IRModule context_mod_; diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 9efa92bd8490..625ae1e76416 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -154,7 +154,7 @@ void RewriteSpec::Append(RewriteSpec other) { return; } - NameSupply gvar_name_supply(""); + UniqueNameSupply gvar_name_supply(""); for (const auto& [gvar, func] : new_subroutines) { gvar_name_supply->ReserveName(gvar->name_hint); } diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 718214d49157..6bbe86d148f9 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -25,7 +25,7 @@ #include #include -#include +#include #include #include #include @@ -96,7 +96,7 @@ class ExternFunctionRewriter : ExprMutator { } private: - NameSupply name_sup_; + UniqueNameSupply name_sup_; /*! \brief A variable that represents the workspace parameter passed from main. */ Var workspace_var_param_; size_t max_workspace_size_ = 0; diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 7c8f6b3854a3..ac3f0611db48 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -268,7 +268,7 @@ class GlobalVarNormalizer : private ExprMutator { } IRModule module_; - NameSupply name_supply_; + UniqueNameSupply name_supply_; ffi::Map gvar_map_; }; diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 8762c83ee4f7..ddedd9ee355a 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -24,7 +24,7 @@ #include #include -#include +#include #include #include diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 934d1af83a36..a023277ed19c 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -357,8 +357,8 @@ class CodeGenC : public ExprFunctor, */ std::unordered_map internal_functions_; - /* \brief Name supply to generate unique function names */ - NameSupply func_name_supply_; + /* \brief Unique unique name supply to generate unique function names */ + UniqueNameSupply func_name_supply_; }; } // namespace codegen diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 5a07e3c7aa07..2646a6597ef4 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -28,7 +28,7 @@ namespace tvm { namespace codegen { void CodeGenSourceBase::ClearFuncState() { - name_supply_ = NameSupply(); + name_supply_ = UniqueNameSupply(); ssa_assign_map_.clear(); var_idmap_.clear(); scope_mark_.clear(); diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 9283944c1b0d..f6e58cc9efba 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -25,7 +25,7 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ #define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ -#include +#include #include #include #include @@ -123,8 +123,8 @@ class CodeGenSourceBase { std::ostringstream fwd_decl_stream; /*! \brief name of each variable */ std::unordered_map var_idmap_; - /*! \brief NameSupply for allocation */ - NameSupply name_supply_; + /*! \brief Unique name supply for allocation */ + UniqueNameSupply name_supply_; /*! \brief The current indentation value */ int indent_{0}; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index a4ce62812a08..5a7223430ed5 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include #include @@ -112,8 +112,8 @@ struct CreateFuncInfo { ProducerToBufferTransformer transformer; /*! \brief The buffers should be allocated at function root. */ ffi::Array root_alloc; - /*! \brief The NameSupply to make block name unique. */ - NameSupply name_supply; + /*! \brief The unique name supply to make block name unique. */ + UniqueNameSupply name_supply; ffi::String FreshName(ffi::String base_name) { return name_supply->FreshName(base_name); } diff --git a/src/tirx/ir/index_map.cc b/src/tirx/ir/index_map.cc index b26ccca248d6..cde0370f7f9d 100644 --- a/src/tirx/ir/index_map.cc +++ b/src/tirx/ir/index_map.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include #include @@ -347,7 +347,7 @@ IndexMap IndexMap::RenameVariables( const std::function(const Var& var)>& f_name_map) const { std::unordered_set used_names; ffi::Map var_remap; - NameSupply name_supply; + UniqueNameSupply name_supply; const IndexMapNode* n = this->get(); if (f_name_map != nullptr) { // Collect variables with pre-defined names provided by f_name_map. diff --git a/src/tirx/transform/bind_target.cc b/src/tirx/transform/bind_target.cc index 7a5627c80bcb..16bf74015200 100644 --- a/src/tirx/transform/bind_target.cc +++ b/src/tirx/transform/bind_target.cc @@ -36,7 +36,7 @@ #include #include -#include +#include #include #include #include @@ -261,7 +261,8 @@ IRModule BindTarget(IRModule mod, const Target& target) { // Track duplicated functions for call replacement ffi::Map host_function_replacements; - GlobalVarSupply gvar_supply(new_mod); + UniqueNameSupply global_names(new_mod->functions.begin(), new_mod->functions.end(), + [](const auto& kv) { return kv.first->name_hint; }); for (auto [gvar, func] : mod->functions) { const auto* prim_func_node = func.as(); @@ -313,7 +314,7 @@ IRModule BindTarget(IRModule mod, const Target& target) { // Create duplicate with host target for host callers host_func = WithAttr(std::move(host_func), tvm::attr::kTarget, target_host); ffi::String host_func_name = gvar->name_hint + "_host"; - GlobalVar host_gvar = gvar_supply->FreshGlobal(host_func_name, false); + GlobalVar host_gvar = GlobalVar(global_names->FreshName(host_func_name, false)); new_mod->Add(host_gvar, host_func); host_function_replacements.Set(gvar, host_gvar); diff --git a/src/tirx/transform/split_host_device.cc b/src/tirx/transform/split_host_device.cc index 7ec104765f3a..acc5e473afb8 100644 --- a/src/tirx/transform/split_host_device.cc +++ b/src/tirx/transform/split_host_device.cc @@ -24,8 +24,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -678,7 +678,8 @@ namespace transform { Pass SplitHostDevice() { auto pass_func = [](IRModule mod, PassContext ctx) { - GlobalVarSupply global_var_supply(mod); + UniqueNameSupply global_names(mod->functions.begin(), mod->functions.end(), + [](const auto& kv) { return kv.first->name_hint; }); IRModule device_mod = IRModule(ffi::Map({})); IRModule updates = IRModule(ffi::Map({})); @@ -691,8 +692,8 @@ Pass SplitHostDevice() { auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); auto name_prefix = global_symbol.value_or(gvar->name_hint); auto kernel_name = name_prefix + "_kernel"; - auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar { - return global_var_supply->FreshGlobal(kernel_name, false); + auto var_supply = [&global_names, &kernel_name]() -> GlobalVar { + return GlobalVar(global_names->FreshName(kernel_name, false)); }; func = SplitHostDevice(std::move(func), &device_mod, var_supply); diff --git a/tests/python/ir/test_name_supply.py b/tests/python/ir/test_unique_name_supply.py similarity index 62% rename from tests/python/ir/test_name_supply.py rename to tests/python/ir/test_unique_name_supply.py index bc3283968d3f..f440301e1feb 100644 --- a/tests/python/ir/test_name_supply.py +++ b/tests/python/ir/test_unique_name_supply.py @@ -16,12 +16,17 @@ # under the License. import tvm import tvm.testing -from tvm.ir.supply import NameSupply +from tvm import relax as rx +from tvm.ir.supply import UniqueNameSupply + + +def _empty_relax_func(): + return rx.Function([], rx.Tuple([])) def test_fresh_name_empty_string(): """Empty name should produce a valid variable name, not an empty string.""" - ns = NameSupply("") + ns = UniqueNameSupply("") name = ns.fresh_name("", add_prefix=False) assert name == "v" name2 = ns.fresh_name("", add_prefix=False) @@ -30,12 +35,28 @@ def test_fresh_name_empty_string(): def test_fresh_name_empty_string_with_prefix(): """Empty name with prefix should produce a valid variable name.""" - ns = NameSupply("prefix") + ns = UniqueNameSupply("prefix") name = ns.fresh_name("", add_prefix=True) assert name == "prefix_v" name2 = ns.fresh_name("", add_prefix=True) assert name2 == "prefix_v_1" +def test_ir_module_from_expr_freshens_main_collision(): + main_gv = tvm.ir.GlobalVar("main") + mod = tvm.IRModule.from_expr(_empty_relax_func(), {main_gv: _empty_relax_func()}) + + assert sorted(gvar.name_hint for gvar in mod.get_global_vars()) == ["main", "main_1"] + + +def test_ir_module_from_expr_reuses_existing_global_symbol(): + foo_gv = tvm.ir.GlobalVar("foo") + func = _empty_relax_func().with_attr("global_symbol", "foo") + mod = tvm.IRModule.from_expr(func, {foo_gv: _empty_relax_func()}) + + assert mod.get_global_var("foo").same_as(foo_gv) + assert [gvar.name_hint for gvar in mod.get_global_vars()] == ["foo"] + + if __name__ == "__main__": tvm.testing.main()