From 23ea1ed6e0046beb7ecb7d9c1002e11646d169b3 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 30 Apr 2026 13:07:57 +0000 Subject: [PATCH 1/4] [REFACTOR][FFI] Drop ffi indirection aliases in tvm headers Remove pure type re-export aliases (`using X = ffi::Y`) from tvm include/ headers and qualify all call sites to use the canonical ffi names directly: - `MemoryScope = ffi::String` in global_info.h and virtual_device.h - `Region = ffi::Array` in tirx/var.h (+ `using tirx::Region` pull-in from arith/bound.h) - `AccessPath = ffi::reflection::AccessPath` in script/printer/doc.h and ir_docsifier.h - `tvm_index_t = ffi::Shape::index_type` in runtime/data_type.h - `FCallPacked = ffi::String` in relax/op_attr_types.h - `TGlobalSymbol = ffi::String` and `TScriptPrinterName = ffi::String` in tirx/op_attr_types.h (also the TVM_TIR_REGISTER_OP macro key) - `Container = ffi::TensorObj` in runtime/tensor.h Remove free-function pull-ins via `using ffi::X`: - `using ffi::EnvErrorAlreadySet` from runtime/logging.h - `using ffi::GetDataSize/IsAligned/IsContiguous` from runtime/tensor.h - `using ffi::DLDataTypeToString/StringToDLDataType` from runtime/data_type.h - `using ffi::Any/Function/PackedArgs` from target/codegen.h Qualify all ~120 call sites across ~100 files to use the fully qualified `ffi::*` names. Function-signature aliases (FCompute, FReduce, meta_schedule Fxxx, NodeFunctor FType, etc.) are kept as they name callback signatures, not type re-exports. Also adds explicit `#include ` to files that use LOG/VLOG/DLOG macros and also had alias rewrites, to ensure each commit is independently buildable once the transitive dependency on logging.h through ir/op.h is removed by the follow-on logging include cleanup commit. --- include/tvm/arith/bound.h | 5 +- include/tvm/ir/global_info.h | 9 +- include/tvm/relax/attrs/op.h | 2 +- include/tvm/relax/op_attr_types.h | 6 - include/tvm/runtime/data_type.h | 13 +- include/tvm/runtime/logging.h | 2 - include/tvm/runtime/tensor.h | 5 - include/tvm/script/printer/doc.h | 22 +-- include/tvm/script/printer/ir_docsifier.h | 14 +- include/tvm/target/codegen.h | 4 - include/tvm/target/virtual_device.h | 21 +-- include/tvm/tirx/op.h | 2 +- include/tvm/tirx/op_attr_types.h | 10 -- include/tvm/tirx/var.h | 2 +- src/arith/domain_touched.cc | 9 +- src/ir/global_info.cc | 4 +- src/ir/script_printer.cc | 10 +- src/ir/structural_hash.cc | 4 +- .../contrib/codegen_json/codegen_json.h | 2 +- src/relax/backend/contrib/nnapi/codegen.cc | 2 +- src/relax/backend/vm/codegen_vm.cc | 2 +- src/relax/op/op.cc | 6 +- src/relax/op/tensor/set.cc | 4 +- src/relax/script/printer/binding.cc | 15 +- src/relax/script/printer/call.cc | 29 +-- src/relax/script/printer/distributed.cc | 13 +- src/relax/script/printer/expr.cc | 23 +-- src/relax/script/printer/function.cc | 165 +++++++++--------- src/relax/script/printer/region.cc | 24 +-- src/relax/script/printer/struct_info.cc | 23 +-- src/relax/script/printer/tir.cc | 29 +-- src/relax/script/printer/type.cc | 16 +- src/relax/script/printer/utils.h | 8 +- src/relax/transform/fuse_tir.cc | 8 +- src/relax/transform/legalize_ops.cc | 5 +- src/runtime/contrib/cudnn/conv_backward.cc | 5 +- src/runtime/contrib/cudnn/conv_forward.cc | 3 +- src/runtime/contrib/nnapi/nnapi_builder.cc | 4 +- src/runtime/contrib/nnapi/nnapi_ops.cc | 2 +- src/runtime/contrib/nnapi/nnapi_runtime.cc | 4 +- src/runtime/contrib/sort/sort.cc | 6 +- .../contrib/tensorrt/tensorrt_builder.cc | 5 +- src/runtime/cuda/cuda_device_api.cc | 11 +- src/runtime/device_api.cc | 6 +- src/runtime/hexagon/hexagon_device_api.cc | 8 +- src/runtime/metadata.h | 4 +- src/runtime/opencl/opencl_device_api.cc | 7 +- src/runtime/rpc/rpc_device_api.cc | 6 +- src/runtime/rpc/rpc_local_session.cc | 2 +- src/runtime/rpc/rpc_module.cc | 2 +- src/runtime/tensor.cc | 16 +- src/runtime/vm/builtin.cc | 6 +- src/runtime/vm/hexagon/builtin.cc | 6 +- src/runtime/vm/tensor_cache_support.cc | 2 +- src/runtime/vm/vm.cc | 3 +- src/s_tir/analysis/is_pure_function.cc | 7 +- .../analysis/sblock_access_region_detector.cc | 6 +- src/s_tir/meta_schedule/arg_info.cc | 4 +- .../multi_level_tiling_tensor_core.cc | 2 +- src/s_tir/schedule/primitive/cache_index.cc | 2 +- .../schedule/primitive/cache_read_write.cc | 22 +-- .../schedule/primitive/decompose_padding.cc | 2 +- .../schedule/primitive/rolling_buffer.cc | 4 +- src/s_tir/support/nd_int_set.h | 2 +- src/s_tir/transform/compact_buffer_region.cc | 16 +- .../transform/inject_software_pipeline.cc | 6 +- src/s_tir/transform/lower_match_buffer.cc | 3 +- .../transform/memhammer_lower_auto_copy.cc | 2 +- src/script/printer/doc.cc | 6 +- .../printer/doc_printer/base_doc_printer.cc | 8 +- .../printer/doc_printer/base_doc_printer.h | 4 +- .../printer/doc_printer/python_doc_printer.cc | 2 +- src/script/printer/ir/distributed.cc | 2 +- src/script/printer/ir/ir.cc | 16 +- src/script/printer/ir/misc.cc | 4 +- src/script/printer/ir_docsifier.cc | 4 +- src/script/printer/utils.h | 5 +- src/target/codegen.cc | 5 +- src/target/cuda/intrin_rule_cuda.cc | 8 +- src/target/llvm/codegen_llvm.cc | 3 +- src/target/llvm/codegen_llvm.h | 2 +- src/target/metal/intrin_rule_metal.cc | 6 +- src/target/source/codegen_c.h | 2 +- src/target/virtual_device.cc | 8 +- src/target/webgpu/intrin_rule_webgpu.cc | 6 +- src/tirx/analysis/verify_memory.cc | 1 + src/tirx/analysis/verify_well_formed.cc | 33 ++-- src/tirx/ir/stmt.cc | 2 +- src/tirx/ir/tir_visitor_with_path.cc | 85 +++++---- src/tirx/op/builtin.cc | 10 +- src/tirx/op/op.cc | 5 +- src/tirx/op/runtime.cc | 4 +- src/tirx/script/builder/ir.cc | 1 + src/tirx/script/printer/block.cc | 20 +-- src/tirx/script/printer/buffer.cc | 40 ++--- src/tirx/script/printer/expr.cc | 51 +++--- src/tirx/script/printer/for_loop.cc | 2 +- src/tirx/script/printer/function.cc | 18 +- src/tirx/script/printer/ir.cc | 14 +- src/tirx/script/printer/stmt.cc | 34 ++-- src/tirx/script/printer/utils.h | 8 +- src/tirx/transform/ir_utils.cc | 4 +- src/tirx/transform/ir_utils.h | 2 +- 103 files changed, 559 insertions(+), 570 deletions(-) diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index 7ae36fb289f0..a286103bed8c 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -33,7 +33,6 @@ namespace tvm { namespace arith { -using tirx::Region; using tirx::Stmt; using tirx::Var; using tirx::VarNode; @@ -77,8 +76,8 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond, * \param consider_stores If stores are considered. * \return The domain that covers all the calls or provides within the given statement. */ -Region DomainTouched(const Stmt& body, const tirx::Buffer& buffer, bool consider_loads, - bool consider_stores); +ffi::Array DomainTouched(const Stmt& body, const tirx::Buffer& buffer, bool consider_loads, + bool consider_stores); } // namespace arith } // namespace tvm diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index 3533b7868732..1fb87654d0c1 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -31,11 +31,6 @@ namespace tvm { -/*! - * \brief Abstract label for an area of memory. - */ -using MemoryScope = ffi::String; - /*! * \brief GlobalInfo are globally static object that are referred by the IR itself. * Base node for all global info that can appear in the IR @@ -67,7 +62,7 @@ class VDeviceNode : public GlobalInfoNode { * differentiate between distinct devices with same Target, such as multiple GPUs. */ int vdevice_id; - MemoryScope memory_scope; + ffi::String memory_scope; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -86,7 +81,7 @@ class VDeviceNode : public GlobalInfoNode { */ class VDevice : public GlobalInfo { public: - TVM_DLL explicit VDevice(Target tgt, int dev_id, MemoryScope mem_scope); + TVM_DLL explicit VDevice(Target tgt, int dev_id, ffi::String mem_scope); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VDevice, GlobalInfo, VDeviceNode); }; diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index 54640901ff53..e7dc64f8005e 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -104,7 +104,7 @@ struct ToVDeviceAttrs : public AttrsNodeReflAdapter { struct HintOnDeviceAttrs : public AttrsNodeReflAdapter { int32_t device_type; int32_t index; - MemoryScope memory_scope; + ffi::String memory_scope; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index 2e686035b20c..85d0c333c8ae 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -61,12 +61,6 @@ enum OpPatternKind { */ using FInferStructInfo = ffi::TypedFunction; -/*! - * \brief Packed function implementation for operators. The relax operator will be lowered to - * this packed function call during codegen. - */ -using FCallPacked = ffi::String; - /*! * \brief The function type of a normalization function. * diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 67fe50350d2f..9f230cac824e 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -26,8 +26,8 @@ #include #include +#include #include -#include #include #include @@ -36,8 +36,6 @@ namespace tvm { namespace runtime { -using tvm_index_t = ffi::Shape::index_type; - /*! * \brief Runtime primitive data type. * @@ -404,10 +402,10 @@ class DataType { * \return The type of TVM shape index. */ static DataType ShapeIndex() { - if (std::is_signed::value) { - return DataType::Int(sizeof(tvm_index_t) * 8); + if (std::is_signed::value) { + return DataType::Int(sizeof(ffi::Shape::index_type) * 8); } else { - return DataType::UInt(sizeof(tvm_index_t) * 8); + return DataType::UInt(sizeof(ffi::Shape::index_type) * 8); } } @@ -451,9 +449,6 @@ inline bool TypeEqual(DLDataType lhs, DLDataType rhs) { return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; } -using ffi::DLDataTypeToString; -using ffi::StringToDLDataType; - inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) return os << dtype.operator DLDataType(); } diff --git a/include/tvm/runtime/logging.h b/include/tvm/runtime/logging.h index d051a01da4c4..68718acc7cb9 100644 --- a/include/tvm/runtime/logging.h +++ b/include/tvm/runtime/logging.h @@ -60,8 +60,6 @@ namespace tvm { namespace runtime { -using ffi::EnvErrorAlreadySet; - /*! \brief Internal implementation */ namespace detail { // Provide support for customized logging. diff --git a/include/tvm/runtime/tensor.h b/include/tvm/runtime/tensor.h index f0ea3508bc91..33a78a48d6ae 100644 --- a/include/tvm/runtime/tensor.h +++ b/include/tvm/runtime/tensor.h @@ -42,17 +42,12 @@ namespace tvm { namespace runtime { -using ffi::GetDataSize; -using ffi::IsAligned; -using ffi::IsContiguous; - /*! * \brief Managed Tensor. * The array is backed by reference counted blocks. */ class Tensor : public tvm::ffi::Tensor { public: - using Container = ffi::TensorObj; Tensor() = default; /*! * \brief constructor. diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 0430c8d8f172..e3d32cb50335 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -31,8 +31,6 @@ namespace tvm { namespace script { namespace printer { -using AccessPath = ffi::reflection::AccessPath; - // Forward declaration class Doc; @@ -260,14 +258,15 @@ class LiteralDocNode : public ExprDocNode { */ class LiteralDoc : public ExprDoc { protected: - explicit LiteralDoc(ffi::Any value, const ffi::Optional& object_path); + explicit LiteralDoc(ffi::Any value, + const ffi::Optional& object_path); public: /*! * \brief Create a LiteralDoc to represent None/null/empty value. * \param p The object path */ - static LiteralDoc None(const ffi::Optional& p) { + static LiteralDoc None(const ffi::Optional& p) { return LiteralDoc(ffi::Any(nullptr), p); } /*! @@ -275,7 +274,7 @@ class LiteralDoc : public ExprDoc { * \param v The integer value. * \param p The object path */ - static LiteralDoc Int(int64_t v, const ffi::Optional& p) { + static LiteralDoc Int(int64_t v, const ffi::Optional& p) { return LiteralDoc(IntImm(DataType::Int(64), v), p); } /*! @@ -283,7 +282,7 @@ class LiteralDoc : public ExprDoc { * \param v The boolean value. * \param p The object path */ - static LiteralDoc Boolean(bool v, const ffi::Optional& p) { + static LiteralDoc Boolean(bool v, const ffi::Optional& p) { return LiteralDoc(IntImm(DataType::Bool(), v), p); } /*! @@ -291,7 +290,7 @@ class LiteralDoc : public ExprDoc { * \param v The float value. * \param p The object path */ - static LiteralDoc Float(double v, const ffi::Optional& p) { + static LiteralDoc Float(double v, const ffi::Optional& p) { return LiteralDoc(FloatImm(DataType::Float(64), v), p); } /*! @@ -299,7 +298,7 @@ class LiteralDoc : public ExprDoc { * \param v The string value. * \param p The object path */ - static LiteralDoc Str(const ffi::String& v, const ffi::Optional& p) { + static LiteralDoc Str(const ffi::String& v, const ffi::Optional& p) { return LiteralDoc(v, p); } /*! @@ -307,8 +306,9 @@ class LiteralDoc : public ExprDoc { * \param v The string value. * \param p The object path */ - static LiteralDoc DataType(const runtime::DataType& v, const ffi::Optional& p) { - std::string dtype = v.is_void() ? "void" : runtime::DLDataTypeToString(v); + static LiteralDoc DataType(const runtime::DataType& v, + const ffi::Optional& p) { + std::string dtype = v.is_void() ? "void" : ffi::DLDataTypeToString(v); return LiteralDoc::Str(dtype, p); } /*! @@ -316,7 +316,7 @@ class LiteralDoc : public ExprDoc { * \param v The device. * \param p The object path */ - static LiteralDoc Device(const DLDevice& v, const ffi::Optional& p) { + static LiteralDoc Device(const DLDevice& v, const ffi::Optional& p) { std::ostringstream os; runtime::operator<<(os, v); return LiteralDoc::Str(os.str(), p); diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index e49d4f8a1cc0..7c0e082fbc5b 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -36,8 +36,6 @@ namespace tvm { namespace script { namespace printer { -using AccessPath = ffi::reflection::AccessPath; - //////////////////////// Frame //////////////////////// class IRDocsifier; @@ -237,7 +235,7 @@ class IRDocsifierNode : public ffi::Object { * \return The Doc object. */ template - inline TDoc AsDoc(const Any& obj, const AccessPath& path) const; + inline TDoc AsDoc(const Any& obj, const ffi::reflection::AccessPath& path) const; }; /*! @@ -245,7 +243,7 @@ class IRDocsifierNode : public ffi::Object { */ class IRDocsifier : public ffi::ObjectRef { public: - using FType = IRDocsifierFunctor; + using FType = IRDocsifierFunctor; /*! \brief Create a IRDocsifier. */ explicit IRDocsifier(const PrinterConfig& cfg); /*! \brief The registration table for IRDocsifier. */ @@ -273,7 +271,8 @@ inline void FrameNode::ExitWithScope() { } template -inline static void AddDocDecoration(const Doc& d, const ffi::ObjectRef& obj, const AccessPath& path, +inline static void AddDocDecoration(const Doc& d, const ffi::ObjectRef& obj, + const ffi::reflection::AccessPath& path, const PrinterConfig& cfg) { if (cfg->obj_to_annotate.count(obj)) { if (const auto* stmt = d.as()) { @@ -293,7 +292,7 @@ inline static void AddDocDecoration(const Doc& d, const ffi::ObjectRef& obj, con } } for (const auto& pair : cfg->path_to_annotate) { - AccessPath p = pair.first; + ffi::reflection::AccessPath p = pair.first; ffi::String attn = pair.second; if (p->IsPrefixOf(path) && path->IsPrefixOf(p)) { if (const auto* stmt = d.as()) { @@ -311,7 +310,8 @@ inline static void AddDocDecoration(const Doc& d, const ffi::ObjectRef& obj, con } template -inline TDoc IRDocsifierNode::AsDoc(const Any& value, const AccessPath& path) const { +inline TDoc IRDocsifierNode::AsDoc(const Any& value, + const ffi::reflection::AccessPath& path) const { switch (value.type_index()) { case ffi::TypeIndex::kTVMFFINone: return Downcast(LiteralDoc::None(path)); diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index 32baf16d4a3a..0274592d6604 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -34,10 +34,6 @@ namespace tvm { /*! \brief namespace for target translation and codegen. */ namespace codegen { -// use packed function from runtime. -using ffi::Any; -using ffi::Function; -using ffi::PackedArgs; /*! * \brief Build a module from array of lowered function. diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index 7829bb61d4ad..bf00144b0f66 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -35,15 +35,6 @@ namespace tvm { -/*! - * Abstract label for an area of memory. - * - * Currently uninterpreted and arbitrary. Likely to be replaced by a structured representation - * of a memory pool in the future. Please try to use this alias instead of ffi::String to aid future - * code migration. - */ -using MemoryScope = ffi::String; - // NOTE: cannot use enum as they are out of bound of the original enum // and results in an undefined behavior // A 'null' device type, does not correspond to any DLDeviceType enum. @@ -67,7 +58,7 @@ constexpr int kInvalidDeviceType = -1; * See "Virtual Devices" below. * - A \p target (\p Target) describing how to compile code for the intended device. May be null * if unconstrained. - * - A \p memory_scope (\p MemoryScope, which is currently just \p String) describing which memory + * - A \p memory_scope (\p ffi::String) describing which memory * area is to be used to hold data. May be "" if unconstrained. See "Memory Scopes and Devices" * below. * @@ -209,7 +200,7 @@ class VirtualDeviceNode : public AttrsNodeReflAdapter { * * Empty denotes unconstrained. */ - MemoryScope memory_scope; + ffi::String memory_scope; /*! * \brief Returns true if virtual device is 'fully unconstrained', ie no target/device type, @@ -279,7 +270,7 @@ class VirtualDevice : public ffi::ObjectRef { */ TVM_DLL explicit VirtualDevice(int device_type_int = kInvalidDeviceType, int virtual_device_id = -1, Target target = {}, - MemoryScope memory_scope = {}); + ffi::String memory_scope = {}); /*! \brief Returns the unique fully unconstrained \p VirtualDevice. */ static VirtualDevice FullyUnconstrained(); @@ -316,13 +307,13 @@ class VirtualDevice : public ffi::ObjectRef { } /*! \brief Returns the \p VirtualDevice for \p memory_scope alone. */ - static VirtualDevice ForMemoryScope(MemoryScope memory_scope) { + static VirtualDevice ForMemoryScope(ffi::String memory_scope) { return VirtualDevice(kInvalidDeviceType, -1, {}, std::move(memory_scope)); } /*! \brief Returns the \p VirtualDevice for \p device, \p target and \p memory_scope. */ TVM_DLL static VirtualDevice ForDeviceTargetAndMemoryScope(const Device& device, Target target, - MemoryScope memory_scope) { + ffi::String memory_scope) { return VirtualDevice(device.device_type, device.device_id, std::move(target), std::move(memory_scope)); } @@ -358,7 +349,7 @@ class TVM_DLL VirtualDeviceCache { public: /*! \brief Returns the unique \p VirtualDevice representing given fields. */ VirtualDevice Make(int device_type = kInvalidDeviceType, int virtual_device_id = -1, - Target target = {}, MemoryScope memory_scope = {}); + Target target = {}, ffi::String memory_scope = {}); /*! * \brief Returns the unique \p VirtualDevice structurally equal to the given \p virtual_device. diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h index c953f12e3870..3a59ada864e6 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -42,7 +42,7 @@ namespace tvm { #define TVM_TIR_REGISTER_OP(OpName) \ - TVM_REGISTER_OP("tirx." OpName).set_attr("TScriptPrinterName", OpName) + TVM_REGISTER_OP("tirx." OpName).set_attr("TScriptPrinterName", OpName) // Most common operators can be overloaded by argument type(PrimExpr). // So we put them under the root namespace. diff --git a/include/tvm/tirx/op_attr_types.h b/include/tvm/tirx/op_attr_types.h index 9d0173bfd49f..2d9aef4b257d 100644 --- a/include/tvm/tirx/op_attr_types.h +++ b/include/tvm/tirx/op_attr_types.h @@ -36,11 +36,6 @@ namespace tvm { namespace tirx { -/*! - * \brief Global symbol of the op after lowering. - */ -using TGlobalSymbol = ffi::String; - /*! * \brief Whether the op is overloaded for vector form. */ @@ -56,11 +51,6 @@ using FLowerIntrinsic = ffi::TypedFunction; */ using FLegalize = ffi::TypedFunction; -/*! - * \brief The operator's name in TVMScript printer - */ -using TScriptPrinterName = ffi::String; - /*! * \brief Specifies that TVMScript printer prints the dtype as the first/last argument. If not specified, dtype will not be printed. diff --git a/include/tvm/tirx/var.h b/include/tvm/tirx/var.h index c38908d56d7d..3d84fb00bc0e 100644 --- a/include/tvm/tirx/var.h +++ b/include/tvm/tirx/var.h @@ -173,7 +173,7 @@ class SizeVar : public Var { using ContainerType = SizeVarNode; }; -using Region = ffi::Array; +// NOTE: Region was an alias for ffi::Array; use ffi::Array directly. /*! * \brief Type of iteration variable. diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index b218eb9e57e9..12ca88a60ca3 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -67,8 +68,8 @@ class BufferTouchedDomain final : public IRVisitorWithAnalyzer { return buffer_access_map_; } - Region FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) { - Region ret; + ffi::Array FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) { + ffi::Array ret; auto kv = buffer_access_map_.find(buffer.get()); if (kv == buffer_access_map_.end()) { LOG(WARNING) << "[arith::BufferDomainTouched] " @@ -132,8 +133,8 @@ class BufferTouchedDomain final : public IRVisitorWithAnalyzer { std::unordered_map buffer_access_map_; }; -Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, - bool consider_stores) { +ffi::Array DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, + bool consider_stores) { return BufferTouchedDomain(stmt).FindUnion(buffer, consider_loads, consider_stores); } diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index d8bba04c5138..4bb37ae1b062 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -39,7 +39,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -VDevice::VDevice(Target tgt, int dev_id, MemoryScope mem_scope) { +VDevice::VDevice(Target tgt, int dev_id, ffi::String mem_scope) { ffi::ObjectPtr n = ffi::make_object(); n->target = std::move(tgt); n->vdevice_id = std::move(dev_id); @@ -49,7 +49,7 @@ VDevice::VDevice(Target tgt, int dev_id, MemoryScope mem_scope) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.VDevice", [](Target tgt, int dev_id, MemoryScope mem_scope) { + refl::GlobalDef().def("ir.VDevice", [](Target tgt, int dev_id, ffi::String mem_scope) { return VDevice(tgt, dev_id, mem_scope); }); } diff --git a/src/ir/script_printer.cc b/src/ir/script_printer.cc index ea0c3d031eae..dc1f035f5cb3 100644 --- a/src/ir/script_printer.cc +++ b/src/ir/script_printer.cc @@ -28,8 +28,6 @@ namespace tvm { -using AccessPath = ffi::reflection::AccessPath; - TVM_FFI_STATIC_INIT_BLOCK() { PrinterConfigNode::RegisterReflection(); } TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { @@ -94,11 +92,13 @@ PrinterConfig::PrinterConfig(ffi::Map config_dict) { } if (auto v = config_dict.Get("path_to_underline")) { n->path_to_underline = - Downcast>>(v).value_or(ffi::Array()); + Downcast>>(v).value_or( + ffi::Array()); } if (auto v = config_dict.Get("path_to_annotate")) { - n->path_to_annotate = Downcast>>(v).value_or( - ffi::Map()); + n->path_to_annotate = + Downcast>>(v).value_or( + ffi::Map()); } if (auto v = config_dict.Get("obj_to_underline")) { n->obj_to_underline = Downcast>>(v).value_or( diff --git a/src/ir/structural_hash.cc b/src/ir/structural_hash.cc index e1903871e175..9f33c2f50a03 100644 --- a/src/ir/structural_hash.cc +++ b/src/ir/structural_hash.cc @@ -60,9 +60,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { return rtmod; }); - refl::TypeAttrDef() + refl::TypeAttrDef() .def("__data_to_json__", - [](const runtime::Tensor::Container* node) { + [](const ffi::TensorObj* node) { std::string result; support::BytesOutStream mstrm(&result); support::Base64OutStream b64strm(&mstrm); diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index bb0f80b82fcd..d7874eb84679 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -91,7 +91,7 @@ class OpAttrExtractor { void Visit(const char* key, DataType* value) { if (!value->is_void()) { - SetNodeAttr(key, ffi::String(runtime::DLDataTypeToString(*value))); + SetNodeAttr(key, ffi::String(ffi::DLDataTypeToString(*value))); } else { SetNodeAttr(key, ffi::String("")); } diff --git a/src/relax/backend/contrib/nnapi/codegen.cc b/src/relax/backend/contrib/nnapi/codegen.cc index 9d85d5ef82d1..757570e69ad5 100644 --- a/src/relax/backend/contrib/nnapi/codegen.cc +++ b/src/relax/backend/contrib/nnapi/codegen.cc @@ -67,7 +67,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { void SetAstypeAttribute(const CallNode* call_node) { const auto* astype_attrs = call_node->attrs.as(); TVM_FFI_ICHECK(astype_attrs); - node_->SetAttr("astype_dtype", ffi::String(runtime::DLDataTypeToString(astype_attrs->dtype))); + node_->SetAttr("astype_dtype", ffi::String(ffi::DLDataTypeToString(astype_attrs->dtype))); } void SetMeanAttribute(const CallNode* call_node) { diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index e9cb175fdcc7..108f746f4eed 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -405,7 +405,7 @@ class CodeGenVM : public ExprFunctor { } // Emits call to packed function `name` with arguments copied over from `call_node` args - void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) { + void EmitPackedFuncCall(const Call& call_node, const ffi::String& name, RegName dst_reg) { std::vector args = VisitArray(call_node->args); builder_->EmitCall(name, args, dst_reg); } diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 1d9e48ca6381..03770e3a1cfc 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -977,7 +977,7 @@ TVM_REGISTER_OP("relax.print") "The first value is Python-style format string to use to print. The others " "are values to print") .set_attr("FInferStructInfo", ReturnVoidStructInfo) - .set_attr("FCallPacked", "relax.run.print") + .set_attr("FCallPacked", "relax.run.print") .set_attr("FPurity", Bool(false)); Expr MakePrint(ffi::Array vals, StringImm format) { @@ -1023,7 +1023,7 @@ TVM_REGISTER_OP("relax.assert_op") "Python-style format string to use for displaying an error message, if the " "assert fails. The others are used as format arguments if there is an error.") .set_attr("FInferStructInfo", InferAssertStructInfo) - .set_attr("FCallPacked", "relax.run.assert_op") + .set_attr("FCallPacked", "relax.run.assert_op") .set_attr("FPurity", Bool(false)); Expr MakeAssertOp(Expr condition, ffi::Array vals, StringImm format) { @@ -1204,7 +1204,7 @@ TVM_REGISTER_OP("relax.shape_to_tensor") .set_num_inputs(1) .add_argument("input", "Expr", "The input expression") .set_attr("FInferStructInfo", ReturnShapeToTensorStructInfo) - .set_attr("FCallPacked", "relax.run.shape_to_tensor") + .set_attr("FCallPacked", "relax.run.shape_to_tensor") .set_attr("FPurity", Bool(true)); Expr MakeShapeToTensor(Expr expr) { diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index 183c254fb8fd..b1e23edb7340 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -166,7 +166,7 @@ TVM_REGISTER_OP("relax.unique") "flattened input " "are returned.") .set_attr("FInferStructInfo", InferStructInfoUnique) - .set_attr("FCallPacked", "relax.run.unique") + .set_attr("FCallPacked", "relax.run.unique") .set_attr("FPurity", Bool(true)); /* relax.nonzero */ @@ -189,7 +189,7 @@ TVM_REGISTER_OP("relax.nonzero") .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoNonzero) - .set_attr("FCallPacked", "relax.run.nonzero") + .set_attr("FCallPacked", "relax.run.nonzero") .set_attr("FPurity", Bool(true)); } // namespace relax diff --git a/src/relax/script/printer/binding.cc b/src/relax/script/printer/binding.cc index ec158a0b6773..da8e6ae8de01 100644 --- a/src/relax/script/printer/binding.cc +++ b/src/relax/script/printer/binding.cc @@ -24,7 +24,8 @@ namespace tvm { namespace script { namespace printer { -IfDoc PrintIfExpr(const relax::If& n, const AccessPath& n_p, const IRDocsifier& d, // +IfDoc PrintIfExpr(const relax::If& n, const ffi::reflection::AccessPath& n_p, + const IRDocsifier& d, // const ffi::Optional& var, const ffi::Optional& ann) { using relax::SeqExpr; ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); @@ -43,7 +44,7 @@ IfDoc PrintIfExpr(const relax::If& n, const AccessPath& n_p, const IRDocsifier& TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::MatchCast n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::MatchCast n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { using relax::StructInfo; using relax::MatchStructInfo; ffi::Optional ann = std::nullopt; @@ -59,7 +60,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::VarBinding n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::VarBinding n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { if (const auto if_ = n->value.as()) { ffi::Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); @@ -84,9 +85,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](relax::If n, AccessPath n_p, IRDocsifier d) -> Doc { - return PrintIfExpr(n, n_p, d, std::nullopt, std::nullopt); - }); + .set_dispatch("", + [](relax::If n, ffi::reflection::AccessPath n_p, + IRDocsifier d) -> Doc { + return PrintIfExpr(n, n_p, d, std::nullopt, std::nullopt); + }); TVM_REGISTER_SCRIPT_AS_REPR(relax::MatchCastNode, ReprPrintRelax); TVM_REGISTER_SCRIPT_AS_REPR(relax::VarBindingNode, ReprPrintRelax); diff --git a/src/relax/script/printer/call.cc b/src/relax/script/printer/call.cc index 262be66e924c..af4cb54f6848 100644 --- a/src/relax/script/printer/call.cc +++ b/src/relax/script/printer/call.cc @@ -29,8 +29,8 @@ namespace printer { class AttrPrinter { public: - explicit AttrPrinter(AccessPath p, const IRDocsifier& d, ffi::Array* keys, - ffi::Array* values) + explicit AttrPrinter(ffi::reflection::AccessPath p, const IRDocsifier& d, + ffi::Array* keys, ffi::Array* values) : p(std::move(p)), d(d), keys(keys), values(values) {} void operator()(const tvm::Attrs& attrs) { @@ -54,13 +54,14 @@ class AttrPrinter { } } - AccessPath p; + ffi::reflection::AccessPath p; const IRDocsifier& d; ffi::Array* keys; ffi::Array* values; }; -ExprDoc PrintCallee(const relax::Expr& n, const AccessPath& n_p, const IRDocsifier& d) { +ExprDoc PrintCallee(const relax::Expr& n, const ffi::reflection::AccessPath& n_p, + const IRDocsifier& d) { // TODO(@junrushao): handle callee better if (const auto* ext = n.as()) { return LiteralDoc::Str(ext->global_symbol, n_p); @@ -69,7 +70,8 @@ ExprDoc PrintCallee(const relax::Expr& n, const AccessPath& n_p, const IRDocsifi } } -ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& n_p, +ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, + const ffi::reflection::AccessPath& n_p, const IRDocsifier& d) { static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); @@ -92,12 +94,12 @@ ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessP args.push_back(d->AsDoc(n->args[1], n_p->Attr("args")->ArrayItem(1))); // Step 3. Print n->sinfo_args, the output struct info relax::StructInfo o_sinfo = n->sinfo_args[0]; - AccessPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayItem(0); + ffi::reflection::AccessPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayItem(0); bool is_dtensor = false; kwargs_keys.push_back("out_sinfo"); if (const auto* o = o_sinfo.as()) { ffi::Array fields; - AccessPath fields_p = o_sinfo_p->Attr("fields"); + ffi::reflection::AccessPath fields_p = o_sinfo_p->Attr("fields"); for (int i = 0, l = o->fields.size(); i < l; ++i) { if (o->fields[i].as()) { is_dtensor = true; @@ -160,7 +162,7 @@ ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessP } } -ffi::Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, +ffi::Optional PrintAssertOp(const relax::Call& n, const ffi::reflection::AccessPath& n_p, const IRDocsifier& d) { static const Op& assert_op = Op::Get("relax.assert_op"); if (!n->op.same_as(assert_op)) { @@ -180,7 +182,8 @@ ffi::Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p return Relax(d, "assert_op")->Call(args, {"format"}, {second_arg}); } -ffi::Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& n_p, +ffi::Optional PrintHintOnDevice(const relax::Call& n, + const ffi::reflection::AccessPath& n_p, const IRDocsifier& d) { static const Op& hint_on_device_op = Op::Get("relax.hint_on_device"); if (!n->op.same_as(hint_on_device_op)) { @@ -203,7 +206,7 @@ ffi::Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& return Relax(d, "hint_on_device")->Call(args); } -ffi::Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_p, +ffi::Optional PrintToVDevice(const relax::Call& n, const ffi::reflection::AccessPath& n_p, const IRDocsifier& d) { static const Op& to_vdevice_op = Op::Get("relax.to_vdevice"); if (!n->op.same_as(to_vdevice_op)) { @@ -227,7 +230,7 @@ ffi::Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_ return Relax(d, "to_vdevice")->Call(args, kwargs_keys, kwargs_values); } -ffi::Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n_p, +ffi::Optional PrintRelaxPrint(const relax::Call& n, const ffi::reflection::AccessPath& n_p, const IRDocsifier& d) { static const Op& print_op = Op::Get("relax.print"); if (!n->op.same_as(print_op)) { @@ -248,7 +251,7 @@ ffi::Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::Call n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::Call n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // Special case: call_tir, call_dps_packed, call_tir_with_grad if (ffi::Optional doc = PrintCallTIRDPSPacked(n, n_p, d)) { return doc.value(); @@ -322,7 +325,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 4. Print type_args if (n->sinfo_args.size() > 0) { - AccessPath sinfo_args_p = n_p->Attr("sinfo_args"); + ffi::reflection::AccessPath sinfo_args_p = n_p->Attr("sinfo_args"); ffi::Array sinfo_args; for (int i = 0, l = n->sinfo_args.size(); i < l; ++i) { sinfo_args.push_back(d->AsDoc(n->sinfo_args[i], sinfo_args_p->ArrayItem(i))); diff --git a/src/relax/script/printer/distributed.cc b/src/relax/script/printer/distributed.cc index 0a67b55af89f..98a96b84ebbe 100644 --- a/src/relax/script/printer/distributed.cc +++ b/src/relax/script/printer/distributed.cc @@ -30,14 +30,17 @@ namespace printer { // distributed::Placement TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", - [](relax::distributed::Placement n, AccessPath n_p, + [](relax::distributed::Placement n, + ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { return d->AsDoc(n->ToString(), n_p); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::distributed::DTensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + "", + [](relax::distributed::DTensorStructInfo n, ffi::reflection::AccessPath n_p, + IRDocsifier d) -> Doc { ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -46,7 +49,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Need to dig into ShapeExpr to preserve the `R.shape` prefix if (const auto* shape = n->tensor_sinfo->shape.value().as()) { auto shape_expr = ffi::GetRef(shape); - AccessPath shape_p = n_p->Attr("shape")->Attr("values"); + ffi::reflection::AccessPath shape_p = n_p->Attr("shape")->Attr("values"); ffi::Array shape_docs; for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { shape_docs.push_back( @@ -91,7 +94,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::distributed::DeviceMesh n, AccessPath n_p, IRDocsifier d) -> Doc { + "", + [](relax::distributed::DeviceMesh n, ffi::reflection::AccessPath n_p, + IRDocsifier d) -> Doc { bool has_relax_frame = false; const IRFrameNode* f = nullptr; for (const Frame& frame : d->frames) { diff --git a/src/relax/script/printer/expr.cc b/src/relax/script/printer/expr.cc index c8a813b8d5ab..b6d750bc2df3 100644 --- a/src/relax/script/printer/expr.cc +++ b/src/relax/script/printer/expr.cc @@ -30,32 +30,32 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::PrimValue n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::PrimValue n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // TODO(@junrushao): float numbers return Relax(d, "prim_value")->Call({d->AsDoc(n->value, n_p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::StringImm n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::StringImm n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "str")->Call({LiteralDoc::Str(n->value, n_p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::DataTypeImm n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::DataTypeImm n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "dtype")->Call({LiteralDoc::DataType(n->value, n_p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::Tuple n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::Tuple n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // TODO(@junrushao): revisit tuple printing if (n->fields.empty()) { return Relax(d, "tuple")->Call({}); } ffi::Array fields_doc; - AccessPath fields_p = n_p->Attr("fields"); + ffi::reflection::AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); } @@ -64,23 +64,24 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TupleGetItem n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TupleGetItem n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { ExprDoc idx = LiteralDoc::Int(n->index, n_p->Attr("index")); return d->AsDoc(n->tuple, n_p->Attr("tuple"))[{idx}]; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ShapeExpr n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ShapeExpr n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array values_doc; - AccessPath values_p = n_p->Attr("values"); + ffi::reflection::AccessPath values_p = n_p->Attr("values"); for (int i = 0, l = n->values.size(); i < l; ++i) { values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayItem(i), d)); } return Relax(d, "shape")->Call({ListDoc(values_doc)}); }); -ffi::Optional SpecialScalar(const runtime::Tensor& n, const AccessPath& p) { +ffi::Optional SpecialScalar(const runtime::Tensor& n, + const ffi::reflection::AccessPath& p) { DataType dtype = n.DataType(); const void* data = n->data; if (n->ndim != 0 || n->device.device_type != kDLCPU) { @@ -135,7 +136,7 @@ ffi::Optional SpecialScalar(const runtime::Tensor& n, const AccessPath& TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::Constant n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::Constant n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { if (ffi::Optional s = SpecialScalar(n->data, n_p->Attr("data"))) { if (n->struct_info_.as()) { ExprDoc ann = d->AsDoc(n->struct_info_, n_p->Attr("struct_info_")); @@ -150,7 +151,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return d->AddMetadata(n); }); -Doc PrintRelaxVar(relax::Var n, AccessPath p, IRDocsifier d) { +Doc PrintRelaxVar(relax::Var n, ffi::reflection::AccessPath p, IRDocsifier d) { if (!d->IsVarDefined(n)) { ExprDoc ann = d->AsDoc(n->struct_info_, p->Attr("struct_info_")); Frame f = d->frames.back(); diff --git a/src/relax/script/printer/function.cc b/src/relax/script/printer/function.cc index e30a2b0bf432..bc4309a7a0d7 100644 --- a/src/relax/script/printer/function.cc +++ b/src/relax/script/printer/function.cc @@ -49,94 +49,99 @@ bool AtTopLevelFunction(const IRDocsifier& d) { TVM_FFI_STATIC_INIT_BLOCK() { RelaxFrameNode::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](relax::Function n, AccessPath n_p, IRDocsifier d) -> Doc { - std::unordered_set func_vars; - With f(d); + .set_dispatch( + "", [](relax::Function n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + std::unordered_set func_vars; + With f(d); - IdDoc func_name(""); - // if we are binding a local definition, then calling d->Define - // will result in a repeated definition and an incorrect displayed name - if (ffi::Optional name = GetBindingName(d)) { - func_name = IdDoc(name.value()); - } else { - func_name = IdDoc(FindFunctionName(d, n).value_or("main")); - } - (*f)->AddDispatchToken(d, "relax"); - (*f)->is_func = true; - (*f)->func_vars = &func_vars; - // Step 1. Print the return type - ffi::Optional ret_type = std::nullopt; - if (const auto& func_sinfo = relax::MatchStructInfo(n)) { - ret_type = d->AsDoc(func_sinfo.value()->ret, // - n_p->Attr("struct_info_")->Attr("ret")); - } - // Step 2. Print params - ffi::Array params; - { - AccessPath params_p = n_p->Attr("params"); - for (int i = 0, l = n->params.size(); i < l; ++i) { - params.push_back(AssignDoc( - /*lhs=*/DefineVar(n->params[i], *f, d), - /*rhs=*/std::nullopt, - StructInfoAsAnn(n->params[i], params_p->ArrayItem(i), d, std::nullopt))); - } - } - // Step 3. Clean up func variables - (*f)->func_vars = nullptr; - // Step 4. Print attributes - if (n->attrs.defined() && !n->attrs->dict.empty()) { - // If the function is a global function and has a global symbol, - // then don't print the global symbol (it will be implicit from not being private). - // For a function without an IR module whose global symbol - // doesn't match the function name, we should still print the global symbol attribute. - if (AtTopLevelFunction(d) && n->attrs->dict.count(tvm::attr::kGlobalSymbol) && - Downcast(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { - ffi::Map new_attrs; - for (auto kv : n->attrs->dict) { - if (kv.first != tvm::attr::kGlobalSymbol) { - new_attrs.Set(kv.first, kv.second); + IdDoc func_name(""); + // if we are binding a local definition, then calling d->Define + // will result in a repeated definition and an incorrect displayed name + if (ffi::Optional name = GetBindingName(d)) { + func_name = IdDoc(name.value()); + } else { + func_name = IdDoc(FindFunctionName(d, n).value_or("main")); + } + (*f)->AddDispatchToken(d, "relax"); + (*f)->is_func = true; + (*f)->func_vars = &func_vars; + // Step 1. Print the return type + ffi::Optional ret_type = std::nullopt; + if (const auto& func_sinfo = relax::MatchStructInfo(n)) { + ret_type = d->AsDoc(func_sinfo.value()->ret, // + n_p->Attr("struct_info_")->Attr("ret")); + } + // Step 2. Print params + ffi::Array params; + { + ffi::reflection::AccessPath params_p = n_p->Attr("params"); + for (int i = 0, l = n->params.size(); i < l; ++i) { + params.push_back(AssignDoc( + /*lhs=*/DefineVar(n->params[i], *f, d), + /*rhs=*/std::nullopt, + StructInfoAsAnn(n->params[i], params_p->ArrayItem(i), d, std::nullopt))); + } + } + // Step 3. Clean up func variables + (*f)->func_vars = nullptr; + // Step 4. Print attributes + if (n->attrs.defined() && !n->attrs->dict.empty()) { + // If the function is a global function and has a global symbol, + // then don't print the global symbol (it will be implicit from not being private). + // For a function without an IR module whose global symbol + // doesn't match the function name, we should still print the global symbol attribute. + if (AtTopLevelFunction(d) && n->attrs->dict.count(tvm::attr::kGlobalSymbol) && + Downcast(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == + func_name->name) { + ffi::Map new_attrs; + for (auto kv : n->attrs->dict) { + if (kv.first != tvm::attr::kGlobalSymbol) { + new_attrs.Set(kv.first, kv.second); + } + } + if (!new_attrs.empty()) { + (*f)->stmts.push_back(ExprStmtDoc( + Relax(d, "func_attr") // + ->Call({d->AsDoc(DictAttrs(new_attrs), n_p->Attr("attrs"))}))); + } + } else { + (*f)->stmts.push_back( + ExprStmtDoc(Relax(d, "func_attr") // + ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); } } - if (!new_attrs.empty()) { - (*f)->stmts.push_back(ExprStmtDoc( - Relax(d, "func_attr") // - ->Call({d->AsDoc(DictAttrs(new_attrs), n_p->Attr("attrs"))}))); + // Step 5. Prepare the decorator (include purity if it's impure) + ExprDoc decorator = Relax(d, "function"); + ffi::Array pos_args = {}; + ffi::Array dec_keys; + ffi::Array dec_values; + if (!n->is_pure) { + dec_keys.push_back("pure"); + dec_values.push_back( + LiteralDoc::Boolean(false, ffi::Optional())); + } + // if the function is global or is not in a module and does not have a global symbol, + // indicate that it's private + if (AtTopLevelFunction(d) && + (!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) { + dec_keys.push_back("private"); + dec_values.push_back( + LiteralDoc::Boolean(true, ffi::Optional())); + } + if (dec_keys.size()) { + decorator = decorator->Call(pos_args, dec_keys, dec_values); } - } else { - (*f)->stmts.push_back( - ExprStmtDoc(Relax(d, "func_attr") // - ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); - } - } - // Step 5. Prepare the decorator (include purity if it's impure) - ExprDoc decorator = Relax(d, "function"); - ffi::Array pos_args = {}; - ffi::Array dec_keys; - ffi::Array dec_values; - if (!n->is_pure) { - dec_keys.push_back("pure"); - dec_values.push_back(LiteralDoc::Boolean(false, ffi::Optional())); - } - // if the function is global or is not in a module and does not have a global symbol, - // indicate that it's private - if (AtTopLevelFunction(d) && - (!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) { - dec_keys.push_back("private"); - dec_values.push_back(LiteralDoc::Boolean(true, ffi::Optional())); - } - if (dec_keys.size()) { - decorator = decorator->Call(pos_args, dec_keys, dec_values); - } - // Step 6. Print body - ffi::Array body = PrintSeqExpr(n->body, n_p->Attr("body"), d, /*use_ret=*/true); - (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); - return HeaderWrapper(d, FunctionDoc(func_name, params, {decorator}, ret_type, (*f)->stmts)); - }); + // Step 6. Print body + ffi::Array body = PrintSeqExpr(n->body, n_p->Attr("body"), d, /*use_ret=*/true); + (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); + return HeaderWrapper(d, + FunctionDoc(func_name, params, {decorator}, ret_type, (*f)->stmts)); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ExternFunc n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ExternFunc n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array args; args.push_back(LiteralDoc::Str(n->global_symbol, n_p->Attr("global_symbol"))); if (!HasDefaultExternFuncStructInfo(n)) { diff --git a/src/relax/script/printer/region.cc b/src/relax/script/printer/region.cc index f5b50e66c97f..561877db8032 100644 --- a/src/relax/script/printer/region.cc +++ b/src/relax/script/printer/region.cc @@ -24,11 +24,11 @@ namespace tvm { namespace script { namespace printer { -ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, +ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const ffi::reflection::AccessPath& n_p, const IRDocsifier& d, bool use_ret) { With f(d); const ffi::Array& blocks = n->blocks; - AccessPath blocks_p = n_p->Attr("blocks"); + ffi::reflection::AccessPath blocks_p = n_p->Attr("blocks"); ffi::Array* stmts = &(*f)->stmts; for (int i = 0, l = blocks.size(); i < l; ++i) { Doc block = d->AsDoc(blocks[i], blocks_p->ArrayItem(i)); @@ -50,19 +50,21 @@ ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](relax::SeqExpr n, AccessPath n_p, IRDocsifier d) -> Doc { - return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false)); - }); + .set_dispatch("", + [](relax::SeqExpr n, ffi::reflection::AccessPath n_p, + IRDocsifier d) -> Doc { + return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false)); + }); -ffi::Array PrintBindingBlock(const relax::BindingBlock& n, const AccessPath& n_p, - const IRDocsifier& d, +ffi::Array PrintBindingBlock(const relax::BindingBlock& n, + const ffi::reflection::AccessPath& n_p, const IRDocsifier& d, ffi::Array* non_dataflow_vars) { const ffi::Array& bindings = n->bindings; - AccessPath bindings_p = n_p->Attr("bindings"); + ffi::reflection::AccessPath bindings_p = n_p->Attr("bindings"); ffi::Array stmts; for (int i = 0, l = bindings.size(); i < l; ++i) { const relax::Binding& binding = bindings[i]; - AccessPath binding_p = bindings_p->ArrayItem(i); + ffi::reflection::AccessPath binding_p = bindings_p->ArrayItem(i); TVM_FFI_ICHECK(binding->var.defined()); Doc binding_doc = d->AsDoc(binding, binding_p); if (const auto* stmt = binding_doc.as()) { @@ -81,13 +83,13 @@ ffi::Array PrintBindingBlock(const relax::BindingBlock& n, const Access TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::BindingBlock n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::BindingBlock n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { return StmtBlockDoc(PrintBindingBlock(n, n_p, d, nullptr)); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::DataflowBlock n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::DataflowBlock n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array non_dataflow_vars; ffi::Array stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars); stmts.push_back(ExprStmtDoc(Relax(d, "output")->Call(non_dataflow_vars))); diff --git a/src/relax/script/printer/struct_info.cc b/src/relax/script/printer/struct_info.cc index 1019cfa7e9bb..f4f054b24c6a 100644 --- a/src/relax/script/printer/struct_info.cc +++ b/src/relax/script/printer/struct_info.cc @@ -27,11 +27,12 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ObjectStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ObjectStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Object"); }); -ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifier& d) { +ExprDoc PrintShapeVar(const PrimExpr& e, const ffi::reflection::AccessPath& e_p, + const IRDocsifier& d) { ExprDoc expr_doc = d->AsDoc(e, e_p); // Step 1. Find if `func_vars` are being collected const RelaxFrameNode* f = nullptr; @@ -63,7 +64,7 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifie TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::PrimStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::PrimStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -80,10 +81,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::ShapeStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ShapeStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { if (n->values.defined()) { ffi::Array shape = n->values.value(); - AccessPath shape_p = n_p->Attr("values"); + ffi::reflection::AccessPath shape_p = n_p->Attr("values"); ffi::Array shape_docs; for (int i = 0, ndim = shape.size(); i < ndim; ++i) { shape_docs.push_back(PrintShapeVar(shape[i], shape_p->ArrayItem(i), d)); @@ -96,7 +97,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TensorStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -104,7 +105,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Need to dig into ShapeExpr to preserve the `R.shape` prefix if (const auto* shape = n->shape.value().as()) { auto shape_expr = ffi::GetRef(shape); - AccessPath shape_p = n_p->Attr("shape")->Attr("values"); + ffi::reflection::AccessPath shape_p = n_p->Attr("shape")->Attr("values"); ffi::Array shape_docs; for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { shape_docs.push_back( @@ -139,12 +140,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TupleStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TupleStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { if (n->fields.empty()) { return Relax(d, "Tuple"); } ffi::Array fields_doc; - AccessPath fields_p = n_p->Attr("fields"); + ffi::reflection::AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); } @@ -153,7 +154,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::FuncStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::FuncStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { auto ret_doc = d->AsDoc(n->ret, n_p->Attr("ret")); auto purity_doc = LiteralDoc::Boolean(n->purity, n_p->Attr("purity")); @@ -179,7 +180,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // TODO(@junrushao): track symbolic shape relation ffi::Array params_doc; ffi::Array params = n->params.value(); - AccessPath params_p = n_p->Attr("params"); + ffi::reflection::AccessPath params_p = n_p->Attr("params"); for (int i = 0, n_params = params.size(); i < n_params; ++i) { params_doc.push_back(d->AsDoc(params[i], params_p->ArrayItem(i))); } diff --git a/src/relax/script/printer/tir.cc b/src/relax/script/printer/tir.cc index e0742f8edd44..345ab9b1dcbf 100644 --- a/src/relax/script/printer/tir.cc +++ b/src/relax/script/printer/tir.cc @@ -42,7 +42,7 @@ RelaxFrameNode* GetRelaxFrame(IRDocsifier d) { return f; } -Doc PrintTIRVar(tirx::Var n, AccessPath n_p, IRDocsifier d) { +Doc PrintTIRVar(tirx::Var n, ffi::reflection::AccessPath n_p, IRDocsifier d) { TVM_FFI_CHECK(n->dtype.is_scalar(), TypeError) << "Relax only uses scalar TIR variables," << "but received TIR variable " << n << " with dtype " << n->dtype; @@ -74,8 +74,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", Prin TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "relax", [](tvm::IntImm n, AccessPath n_p, IRDocsifier d) -> Doc { // + .set_dispatch( // + "relax", [](tvm::IntImm n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // // TODO(@junrushao): support non-int64 cases if (n->dtype.is_bool()) { return LiteralDoc::Boolean(n->value, n_p); @@ -85,8 +85,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "relax", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // + .set_dispatch( // + "relax", [](tvm::GlobalVar n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } else { @@ -97,8 +97,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "relax", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // + .set_dispatch( // + "relax", [](tvm::IRModule mod, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // ffi::Optional doc = d->GetVarDoc(mod); TVM_FFI_ICHECK(doc) << "Unable to print IRModule before definition in Relax."; if (d->cfg->module_alias.empty()) { @@ -118,13 +118,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("relax", [](Range range, AccessPath p, IRDocsifier d) -> Doc { - return Relax(d, "Range") - ->Call({ - d->AsDoc(range->min, p->Attr("min")), - d->AsDoc(range->extent + range->min, p->Attr("extent")), - }); - }); + .set_dispatch("relax", + [](Range range, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + return Relax(d, "Range") + ->Call({ + d->AsDoc(range->min, p->Attr("min")), + d->AsDoc(range->extent + range->min, p->Attr("extent")), + }); + }); } // namespace printer } // namespace script diff --git a/src/relax/script/printer/type.cc b/src/relax/script/printer/type.cc index f5cbfcb16615..1c01972c517d 100644 --- a/src/relax/script/printer/type.cc +++ b/src/relax/script/printer/type.cc @@ -26,20 +26,20 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ShapeType n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ShapeType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Shape") ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ObjectType n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ObjectType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Object"); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TensorType n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TensorType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Tensor") ->Call({}, {"ndim", "dtype"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim")), @@ -48,18 +48,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::PackedFuncType n, AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::PackedFuncType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "PackedFunc"); // TODO(@junrushao): verify if this is correct }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "relax", [](tvm::TupleType n, AccessPath n_p, IRDocsifier d) -> Doc { + "relax", [](tvm::TupleType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { if (n->fields.empty()) { return Relax(d, "Tuple"); } ffi::Array fields_doc; - AccessPath fields_p = n_p->Attr("fields"); + ffi::reflection::AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); } @@ -68,10 +68,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "relax", [](tvm::FuncType n, AccessPath n_p, IRDocsifier d) -> Doc { + "relax", [](tvm::FuncType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array arg_types_doc; ffi::Array arg_types = n->arg_types; - AccessPath arg_types_p = n_p->Attr("arg_types"); + ffi::reflection::AccessPath arg_types_p = n_p->Attr("arg_types"); for (int i = 0, n_params = arg_types.size(); i < n_params; ++i) { arg_types_doc.push_back(d->AsDoc(arg_types[i], arg_types_p->ArrayItem(i))); } diff --git a/src/relax/script/printer/utils.h b/src/relax/script/printer/utils.h index 607728cb5b69..aabb73abd2cd 100644 --- a/src/relax/script/printer/utils.h +++ b/src/relax/script/printer/utils.h @@ -79,7 +79,8 @@ inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsif return d->Define(var, frame, var->name_hint().empty() ? "v" : var->name_hint()); } -inline ffi::Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& v_p, +inline ffi::Optional StructInfoAsAnn(const relax::Var& v, + const ffi::reflection::AccessPath& v_p, const IRDocsifier& d, const ffi::Optional& rhs) { if (!v->struct_info_.defined()) { @@ -133,10 +134,11 @@ inline ffi::Optional StructInfoAsAnn(const relax::Var& v, const AccessP return d->AsDoc(v->struct_info_, v_p->Attr("struct_info_")); } -ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, +ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const ffi::reflection::AccessPath& n_p, const IRDocsifier& d, bool use_ret); -ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifier& d); +ExprDoc PrintShapeVar(const PrimExpr& e, const ffi::reflection::AccessPath& e_p, + const IRDocsifier& d); inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifier& d) { ffi::Array vdevices = d->global_infos["vdevice"]; diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 7e8ca65b0d7c..bb29a798dc4c 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -237,7 +237,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) { const Buffer& src_buffer = SubstituteBuffer(match_buffer->source->buffer); const Buffer& tgt_buffer = SubstituteAllocatedBuffer(match_buffer->buffer); - Region region = MutateRegion(match_buffer->source->region); + ffi::Array region = MutateRegion(match_buffer->source->region); if (src_buffer.same_as(match_buffer->source->buffer) && tgt_buffer.same_as(match_buffer->buffer) && region.same_as(match_buffer->source->region)) { @@ -252,7 +252,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) { const Buffer& buffer = SubstituteBuffer(buffer_region->buffer); - const Region& region = MutateRegion(buffer_region->region); + const ffi::Array& region = MutateRegion(buffer_region->region); if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) { return buffer_region; } else { @@ -302,7 +302,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { // However, `A[vi, vj], A[vi, vj + 1]` is not allow for now. // Note: the order of return region should remain the same as the first occurrence of the region ffi::Array ret; - std::unordered_map buffer_region_set; + std::unordered_map> buffer_region_set; for (const BufferRegion& region : regions) { auto it = buffer_region_set.find(region->buffer.get()); @@ -328,7 +328,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { } } - inline Region MutateRegion(const Region& region) { + inline ffi::Array MutateRegion(const ffi::Array& region) { return MutateArray(region, [this](const Range& range) { const PrimExpr& min = this->VisitExpr(range->min); const PrimExpr& extent = this->VisitExpr(range->extent); diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 31e1625526c0..6b618608162d 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -235,7 +236,7 @@ class LegalizeMutator : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { Call visited_call = Downcast(this->VisitExprPostOrder_(call)); static const auto& legalize_map = Op::GetAttrMap("FLegalize"); - static const auto& call_packed_map = Op::GetAttrMap("FCallPacked"); + static const auto& call_packed_map = Op::GetAttrMap("FCallPacked"); static const auto& requires_arg_shapes_map = Op::GetAttrMap("RequiresArgumentShapes"); static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); static const Op& call_tir_op = Op::Get("relax.call_tir"); @@ -328,7 +329,7 @@ class LegalizeMutator : public ExprMutator { // Second choice, use a default legalization legalization_func = legalize_map[op]; } else if (call_packed_map.count(op)) { - // Third choice, use an explicit FCallPacked replacement. This does not require the shape + // Third choice, use an explicit ffi::String replacement. This does not require the shape ffi::String packed_func_name = call_packed_map[op]; legalization_func = [packed_func_name](const BlockBuilder& bb, const Call& call) -> Expr { return Call(ExternFunc(packed_func_name), call->args, Attrs(), {GetStructInfo(call)}); diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc index d26f82645eaf..bfc65baaff93 100644 --- a/src/runtime/contrib/cudnn/conv_backward.cc +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include "cudnn_utils.h" @@ -78,7 +79,7 @@ void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], con dx_dim_int64[i] = dx_dim[i]; } SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx_dim_int64.data(), - w_dim_int64.data(), dy_dim_int64.data(), StringToDLDataType(data_dtype), + w_dim_int64.data(), dy_dim_int64.data(), ffi::StringToDLDataType(data_dtype), conv_dtype); int returned_algo_count = 0; @@ -157,7 +158,7 @@ void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], c dw_dim_int64[i] = dw_dim[i]; } SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(), - dw_dim_int64.data(), dy_dim_int64.data(), StringToDLDataType(data_dtype), + dw_dim_int64.data(), dy_dim_int64.data(), ffi::StringToDLDataType(data_dtype), conv_dtype); int returned_algo_count = 0; diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index 6a5737c183b0..6c6fd7eb4036 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include "cudnn_utils.h" @@ -123,7 +124,7 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid y_dim_int64[i] = y_dim[i]; } SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(), - w_dim_int64.data(), y_dim_int64.data(), StringToDLDataType(data_dtype), + w_dim_int64.data(), y_dim_int64.data(), ffi::StringToDLDataType(data_dtype), conv_dtype); int returned_algo_count = 0; diff --git a/src/runtime/contrib/nnapi/nnapi_builder.cc b/src/runtime/contrib/nnapi/nnapi_builder.cc index 044ff1ccd4a8..8491e1a75939 100644 --- a/src/runtime/contrib/nnapi/nnapi_builder.cc +++ b/src/runtime/contrib/nnapi/nnapi_builder.cc @@ -22,7 +22,7 @@ #include "nnapi_builder.h" #include -#include +#include #include #include @@ -138,7 +138,7 @@ NNAPIModelBuilder::~NNAPIModelBuilder() { ANeuralNetworksModel_free(model_); } NNAPIOperand NNAPIModelBuilder::CreateOperandWithValue(const DLTensor& tensor) { NNAPIOperand operand(next_operand_index_++, &tensor); - const size_t operand_data_size = GetDataSize(tensor); + const size_t operand_data_size = ffi::GetDataSize(tensor); TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), ANEURALNETWORKS_NO_ERROR); diff --git a/src/runtime/contrib/nnapi/nnapi_ops.cc b/src/runtime/contrib/nnapi/nnapi_ops.cc index a6b5a9c221a7..4a8bf5ba97aa 100644 --- a/src/runtime/contrib/nnapi/nnapi_ops.cc +++ b/src/runtime/contrib/nnapi/nnapi_ops.cc @@ -273,7 +273,7 @@ void CastOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& n // Extract the dtype attribute and check that the output operand type matches the dtype specified. const auto dtype_str = node.GetAttr("astype_dtype"); - const DLDataType dtype = StringToDLDataType(std::string(dtype_str)); + const DLDataType dtype = ffi::StringToDLDataType(std::string(dtype_str)); TVM_FFI_ICHECK(outputs.size() == 1); const auto output_tensor_type = outputs[0].GetTensorType(); TVM_FFI_ICHECK(TensorTypeFromDLDataType(dtype) == output_tensor_type) diff --git a/src/runtime/contrib/nnapi/nnapi_runtime.cc b/src/runtime/contrib/nnapi/nnapi_runtime.cc index 1939e90992e7..46329f201ac2 100644 --- a/src/runtime/contrib/nnapi/nnapi_runtime.cc +++ b/src/runtime/contrib/nnapi/nnapi_runtime.cc @@ -145,7 +145,7 @@ class NNAPIRuntime : public JSONRuntimeBase { const uint32_t eid = EntryID(nid, j); const auto entry = data_entry_[eid]; - const auto operand_data_size = GetDataSize(*entry); + const auto operand_data_size = ffi::GetDataSize(*entry); TVM_FFI_ICHECK_EQ( ANeuralNetworksExecution_setInput(execution, i, operand.GetOperandType().Get(), entry->data, operand_data_size), @@ -161,7 +161,7 @@ class NNAPIRuntime : public JSONRuntimeBase { const auto eid = EntryID(node); const auto entry = data_entry_[eid]; - const auto operand_data_size = GetDataSize(*entry); + const auto operand_data_size = ffi::GetDataSize(*entry); TVM_FFI_ICHECK_EQ( ANeuralNetworksExecution_setOutput(execution, i, operand.GetOperandType().Get(), entry->data, operand_data_size), diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 541548d18250..0d072c963846 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -22,10 +22,10 @@ */ #include +#include #include #include #include -#include #include #include @@ -334,8 +334,8 @@ void RegisterSort() { "input ndim " << input->ndim; - auto data_dtype = DLDataTypeToString(input->dtype); - auto out_dtype = DLDataTypeToString(output->dtype); + auto data_dtype = ffi::DLDataTypeToString(input->dtype); + auto out_dtype = ffi::DLDataTypeToString(output->dtype); TVM_FFI_ICHECK_EQ(data_dtype, out_dtype); diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index 63d886e520a7..4caa8e383e15 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -24,6 +24,7 @@ #include "tensorrt_builder.h" +#include #include #include @@ -227,10 +228,10 @@ nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, const auto trt_dtype = (static_cast(dptr->dtype.bits) == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT; - const size_t weight_bytes = GetDataSize(*dptr); + const size_t weight_bytes = ffi::GetDataSize(*dptr); nvinfer1::Weights weight{trt_dtype, nullptr, 0}; size_t count = 1; - for (tvm_index_t i = 0; i < dptr->ndim; ++i) { + for (ffi::Shape::index_type i = 0; i < dptr->ndim; ++i) { count *= dptr->shape[i]; } weight.count = count; diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index a01d223ff6f3..5de47bd3e431 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -454,7 +455,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { break; default: TVM_FFI_THROW(InternalError) - << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); + << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; case DataType::kUInt: @@ -474,7 +475,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { break; default: TVM_FFI_THROW(InternalError) - << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); + << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; case DataType::kFloat: @@ -491,7 +492,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { break; default: TVM_FFI_THROW(InternalError) - << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); + << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; case DataType::kBFloat: @@ -502,7 +503,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { break; default: TVM_FFI_THROW(InternalError) - << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); + << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; case DataType::kFloat8_e4m3fn: @@ -515,7 +516,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { break; default: TVM_FFI_THROW(InternalError) - << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); + << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } // sanity checks per cuTensorMapEncodeTiled requirements diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index d8aff594ea95..959cd619abbc 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -132,7 +132,7 @@ void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDa temp.shape = const_cast(shape); temp.strides = nullptr; temp.byte_offset = 0; - size_t size = GetDataSize(temp); + size_t size = ffi::GetDataSize(temp); size_t alignment = GetDataAlignment(temp.dtype); return AllocDataSpace(dev, size, alignment, dtype); } @@ -143,8 +143,8 @@ void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDa void DeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { // by default, we can always redirect to the flat memory copy operation. - size_t nbytes = GetDataSize(*from); - TVM_FFI_ICHECK_EQ(nbytes, GetDataSize(*to)); + size_t nbytes = ffi::GetDataSize(*from); + TVM_FFI_ICHECK_EQ(nbytes, ffi::GetDataSize(*to)); TVM_FFI_ICHECK(ffi::IsContiguous(*from) && ffi::IsContiguous(*to)) << "CopyDataFromTo only support contiguous array for now"; diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index 0d1c432571c6..ae0e0862dfc2 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -168,7 +168,7 @@ void HexagonDeviceAPI::FreeWorkspace(Device dev, void* data) { void HexagonDeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { TVM_FFI_ICHECK_EQ(from->byte_offset, 0); TVM_FFI_ICHECK_EQ(to->byte_offset, 0); - TVM_FFI_ICHECK_EQ(GetDataSize(*from), GetDataSize(*to)); + TVM_FFI_ICHECK_EQ(ffi::GetDataSize(*from), ffi::GetDataSize(*to)); TVM_FFI_ICHECK(runtime_hexbuffs) << "Attempted to copy Hexagon data with " << "HexagonDeviceAPI::CopyDataFromTo before initializing resources. " @@ -182,11 +182,11 @@ void HexagonDeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHan HexagonBuffer* hex_to_buf = lookup_hexagon_buffer(to->data); if (hex_from_buf && hex_to_buf) { - hex_to_buf->CopyFrom(*hex_from_buf, GetDataSize(*from)); + hex_to_buf->CopyFrom(*hex_from_buf, ffi::GetDataSize(*from)); } else if (hex_to_buf) { - hex_to_buf->CopyFrom(from->data, GetDataSize(*from)); + hex_to_buf->CopyFrom(from->data, ffi::GetDataSize(*from)); } else if (hex_from_buf) { - hex_from_buf->CopyTo(to->data, GetDataSize(*to)); + hex_from_buf->CopyTo(to->data, ffi::GetDataSize(*to)); } else { TVM_FFI_ICHECK(false) << "CopyDataFromTo requested between src and dst which are not managed by the " diff --git a/src/runtime/metadata.h b/src/runtime/metadata.h index e85d53b07cbe..c034041ce4a4 100644 --- a/src/runtime/metadata.h +++ b/src/runtime/metadata.h @@ -74,7 +74,7 @@ class FunctionInfoObj : public ffi::Object { obj.Set("name", name); json::Array sarg_types; for (const auto& t : arg_types) { - sarg_types.push_back(ffi::String(DLDataTypeToString(t))); + sarg_types.push_back(ffi::String(ffi::DLDataTypeToString(t))); } obj.Set("arg_types", std::move(sarg_types)); { @@ -96,7 +96,7 @@ class FunctionInfoObj : public ffi::Object { auto sarg_types_arr = src.at("arg_types").cast(); arg_types = ffi::Array(); for (size_t i = 0; i < sarg_types_arr.size(); ++i) { - arg_types.push_back(StringToDLDataType(std::string(sarg_types_arr[i].cast()))); + arg_types.push_back(ffi::StringToDLDataType(std::string(sarg_types_arr[i].cast()))); } auto lt = src.find("launch_param_tags"); if (lt != src.end()) { diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 0b63f497dbff..952a9b67141c 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include @@ -507,9 +508,9 @@ void OpenCLWorkspace::FreeDataSpace(Device dev, void* ptr) { void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { this->Init(); - size_t nbytes = GetDataSize(*from); - TVM_FFI_ICHECK_EQ(nbytes, GetDataSize(*to)); - TVM_FFI_ICHECK(IsContiguous(*from) && IsContiguous(*to)) + size_t nbytes = ffi::GetDataSize(*from); + TVM_FFI_ICHECK_EQ(nbytes, ffi::GetDataSize(*to)); + TVM_FFI_ICHECK(ffi::IsContiguous(*from) && ffi::IsContiguous(*to)) << "CopyDataFromTo only support contiguous array for now"; if (IsOpenCLDevice(from->device) && IsOpenCLDevice(to->device)) { diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 579b45abb31c..6e0dd162b3ba 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include @@ -98,14 +98,14 @@ class RPCDeviceAPI final : public DeviceAPI { from_tensor.device = RemoveRPCSessionMask(dev_from); from_tensor.data = static_cast(from->data)->data; void* to_bytes = static_cast(to->data) + to->byte_offset; - size_t nbytes = GetDataSize(*to); + size_t nbytes = ffi::GetDataSize(*to); GetSess(dev_from)->CopyFromRemote(&from_tensor, to_bytes, nbytes); } else if (dev_from.device_type == kDLCPU && IsRPCSessionDevice(dev_to)) { DLTensor to_tensor = *to; to_tensor.device = RemoveRPCSessionMask(dev_to); to_tensor.data = static_cast(to->data)->data; void* from_bytes = static_cast(from->data) + from->byte_offset; - size_t nbytes = GetDataSize(*from); + size_t nbytes = ffi::GetDataSize(*from); GetSess(dev_to)->CopyToRemote(from_bytes, &to_tensor, nbytes); } else { TVM_FFI_THROW(InternalError) << "expect copy from/to remote or between remote"; diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 5094bc678bac..0a670ceb941c 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -106,7 +106,7 @@ void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, ffi::PackedArgs a } void LocalSession::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) { - TVM_FFI_ICHECK_EQ(nbytes, GetDataSize(*to)); + TVM_FFI_ICHECK_EQ(nbytes, ffi::GetDataSize(*to)); DLTensor from; from.data = from_bytes; from.device = {kDLCPU, 0}; diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 3bf7da474ee2..74f19ab3e3bb 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -399,7 +399,7 @@ inline void CPUCacheFlushImpl(const char* addr, unsigned int len) { inline void CPUCacheFlush(int begin_index, const ffi::PackedArgs& args) { for (int i = begin_index; i < args.size(); i++) { CPUCacheFlushImpl(static_cast((args[i].cast()->data)), - GetDataSize(*(args[i].cast()))); + ffi::GetDataSize(*(args[i].cast()))); } } diff --git a/src/runtime/tensor.cc b/src/runtime/tensor.cc index d4fe1772b978..61f037caec55 100644 --- a/src/runtime/tensor.cc +++ b/src/runtime/tensor.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include "tvm/runtime/data_type.h" @@ -60,9 +60,9 @@ inline void VerifyDataType(DLDataType dtype) { } void TensorCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { - size_t arr_size = GetDataSize(*handle); + size_t arr_size = ffi::GetDataSize(*handle); TVM_FFI_ICHECK_EQ(arr_size, nbytes) << "TensorCopyFromBytes: size mismatch"; - TVM_FFI_ICHECK(IsContiguous(*handle)) + TVM_FFI_ICHECK(ffi::IsContiguous(*handle)) << "TensorCopyFromBytes only support contiguous array for now"; DLTensor from; @@ -80,7 +80,7 @@ void TensorCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { void Tensor::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, TVMStreamHandle stream) { - size_t arr_size = GetDataSize(*handle); + size_t arr_size = ffi::GetDataSize(*handle); TVM_FFI_ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; TVM_FFI_ICHECK(ffi::IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; @@ -101,7 +101,7 @@ void Tensor::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, void Tensor::CopyFromBytes(const DLTensor* handle, void* data, size_t nbytes, TVMStreamHandle stream) { - size_t arr_size = GetDataSize(*handle); + size_t arr_size = ffi::GetDataSize(*handle); TVM_FFI_ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; TVM_FFI_ICHECK(ffi::IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; @@ -160,7 +160,7 @@ Tensor Tensor::CreateView(ffi::Shape shape, DLDataType dtype, uint64_t relative_ return ss.str(); }(); const auto& curr_dl_tensor = *get_mutable(); - size_t curr_size = GetDataSize(curr_dl_tensor); + size_t curr_size = ffi::GetDataSize(curr_dl_tensor); size_t view_size = ffi::GetDataSize(shape.Product(), dtype); TVM_FFI_CHECK_LE(relative_byte_offset + view_size, curr_size, ValueError) << "View with shape " << shape << " and datatype " << dtype << " would have a size of " @@ -215,8 +215,8 @@ Tensor Tensor::CopyTo(const Device& dev, ffi::Optional mem_scope) c } void Tensor::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream) { - size_t from_size = GetDataSize(*from); - size_t to_size = GetDataSize(*to); + size_t from_size = ffi::GetDataSize(*from); + size_t to_size = ffi::GetDataSize(*to); TVM_FFI_ICHECK_EQ(from_size, to_size) << "TVMTensorCopyFromTo: The size in bytes must exactly match."; diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 706d74339097..f5485e7a3326 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include @@ -611,7 +611,7 @@ bool ReadIfCond(ffi::AnyView cond) { break; } default: - TVM_FFI_THROW(InternalError) << "Unknown scalar int type: " << DLDataTypeToString(arr->dtype); + TVM_FFI_THROW(InternalError) << "Unknown scalar int type: " << ffi::DLDataTypeToString(arr->dtype); throw; } return result != 0; @@ -702,7 +702,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } default: TVM_FFI_THROW(InternalError) - << "Unknown scalar int type: " << DLDataTypeToString(arr->dtype); + << "Unknown scalar int type: " << ffi::DLDataTypeToString(arr->dtype); throw; } out_shape.push_back(result); diff --git a/src/runtime/vm/hexagon/builtin.cc b/src/runtime/vm/hexagon/builtin.cc index 54fd70b2800f..c7429975647f 100644 --- a/src/runtime/vm/hexagon/builtin.cc +++ b/src/runtime/vm/hexagon/builtin.cc @@ -45,8 +45,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { void* src = sptr->data; int ret = DMA_RETRY; - TVM_FFI_ICHECK_EQ(GetDataSize(*dptr), GetDataSize(*sptr)); - auto size = GetDataSize(*dptr); + TVM_FFI_ICHECK_EQ(ffi::GetDataSize(*dptr), ffi::GetDataSize(*sptr)); + auto size = ffi::GetDataSize(*dptr); TVM_FFI_ICHECK(size > 0); if (bypass_cache) qurt_mem_cache_clean(reinterpret_cast(src), size, @@ -65,7 +65,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (bypass_cache) { const DLTensor* dptr = dst_arr.operator->(); void* dst = dptr->data; - auto size = GetDataSize(*dptr); + auto size = ffi::GetDataSize(*dptr); qurt_mem_cache_clean(reinterpret_cast(dst), size, QURT_MEM_CACHE_FLUSH, QURT_MEM_DCACHE); } diff --git a/src/runtime/vm/tensor_cache_support.cc b/src/runtime/vm/tensor_cache_support.cc index 1804dcc622b1..ee77c5ddd8f0 100644 --- a/src/runtime/vm/tensor_cache_support.cc +++ b/src/runtime/vm/tensor_cache_support.cc @@ -137,7 +137,7 @@ void CopyTensorFromBytes(Tensor param, const void* data, size_t nbytes, // It creates a host side memory mirror, for every cl_mem that tries to copy data from host // which can cause memory issue. Her we use a large staging buffer to postpone deallocation if (staging_buffer->defined()) { - size_t curr_size = runtime::GetDataSize(*(staging_buffer->value().operator->())); + size_t curr_size = ffi::GetDataSize(*(staging_buffer->value().operator->())); if (curr_size < nbytes) { *staging_buffer = std::nullopt; } diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index d3fa356792ee..b7e29710aff9 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include @@ -779,7 +780,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { for (int i = 0; i < instr.num_args; ++i) { if (call_args[i + args_begin_offset].type_index() == ffi::TypeIndex::kTVMFFIDataType) { std::string str_dtype = - DLDataTypeToString(call_args[i + args_begin_offset].cast()); + ffi::DLDataTypeToString(call_args[i + args_begin_offset].cast()); temp_dtype.emplace_back(std::make_unique(str_dtype)); call_args[i + args_begin_offset] = *temp_dtype.back(); } diff --git a/src/s_tir/analysis/is_pure_function.cc b/src/s_tir/analysis/is_pure_function.cc index 40feab7f5c80..2ca557b171d1 100644 --- a/src/s_tir/analysis/is_pure_function.cc +++ b/src/s_tir/analysis/is_pure_function.cc @@ -33,7 +33,6 @@ namespace tvm { namespace s_tir { using namespace tvm::tirx; -using AccessPath = ffi::reflection::AccessPath; namespace { class PurityChecker : TIRVisitorWithPath { @@ -47,12 +46,12 @@ class PurityChecker : TIRVisitorWithPath { private: explicit PurityChecker(bool assert_on_error) : assert_on_error_(assert_on_error) {} - void VisitStmt_(const AllocBufferNode* op, AccessPath path) override { + void VisitStmt_(const AllocBufferNode* op, ffi::reflection::AccessPath path) override { internal_allocations_.insert(op->buffer->data); TIRVisitorWithPath::VisitStmt_(op, path); } - void VisitStmt_(const BufferStoreNode* op, AccessPath path) override { + void VisitStmt_(const BufferStoreNode* op, ffi::reflection::AccessPath path) override { TIRVisitorWithPath::VisitStmt_(op, path); if (!internal_allocations_.count(op->buffer->data)) { @@ -65,7 +64,7 @@ class PurityChecker : TIRVisitorWithPath { } } - void VisitExpr_(const CallNode* call, AccessPath path) override { + void VisitExpr_(const CallNode* call, ffi::reflection::AccessPath path) override { TIRVisitorWithPath::VisitExpr_(call, path); static auto op_call_effect = Op::GetAttrMap("TCallEffectKind"); diff --git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index 19fc8fd090fb..918b9d2815d8 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -204,7 +204,7 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { if (it != buffer_var_map_.end()) { const Buffer& buffer = (*it).second; const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); - const Region& region = buffer_region->region; + const ffi::Array& region = buffer_region->region; std::vector int_set; int_set.reserve(region.size()); for (const Range& range : region) { @@ -287,7 +287,7 @@ std::vector BlockReadWriteDetector::ConvertMatchedRegion( const MatchBufferRegion& match_buffer, const std::vector& int_sets) const { const Buffer& buffer = match_buffer->buffer; - Region region; + ffi::Array region; region.reserve(int_sets.size()); TVM_FFI_ICHECK_EQ(buffer->shape.size(), int_sets.size()); for (size_t i = 0; i < int_sets.size(); ++i) { @@ -363,7 +363,7 @@ void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) { if (it != buffer_var_map_.end()) { const Buffer& buffer = (*it).second; const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); - const Region& region = buffer_region->region; + const ffi::Array& region = buffer_region->region; std::vector int_set; int_set.reserve(region.size()); for (const Range& range : region) { diff --git a/src/s_tir/meta_schedule/arg_info.cc b/src/s_tir/meta_schedule/arg_info.cc index 5411d46cc20b..87c6715a9841 100644 --- a/src/s_tir/meta_schedule/arg_info.cc +++ b/src/s_tir/meta_schedule/arg_info.cc @@ -126,7 +126,7 @@ TensorInfo::TensorInfo(runtime::DataType dtype, ffi::Shape shape) { ffi::ObjectRef TensorInfoNode::AsJSON() const { static ffi::String tag = "TENSOR"; - ffi::String dtype = DLDataTypeToString(this->dtype); + ffi::String dtype = ffi::DLDataTypeToString(this->dtype); ffi::Array shape = support::AsArray(this->shape); return ffi::Array{tag, dtype, shape}; } @@ -140,7 +140,7 @@ TensorInfo TensorInfo::FromJSON(const ffi::ObjectRef& json_obj) { // Load json[1] => dtype { ffi::String dtype_str = json_array->at(1).cast(); - dtype = StringToDLDataType(dtype_str); + dtype = ffi::StringToDLDataType(dtype_str); } // Load json[2] => shape shape = AsIntArray(json_array->at(2).cast()); diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 00e51cbb1ebf..2dc9de361e8f 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -820,7 +820,7 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i]; } - auto f_get_sub_index_map = [&](const tirx::Buffer& lhs_buffer, const tirx::Region& lhs_region) { + auto f_get_sub_index_map = [&](const tirx::Buffer& lhs_buffer, const ffi::Array& lhs_region) { std::vector sub_index_map_src; std::vector sub_index_map_tgt; const tirx::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer]; diff --git a/src/s_tir/schedule/primitive/cache_index.cc b/src/s_tir/schedule/primitive/cache_index.cc index 9566817f8015..e162117b67b0 100644 --- a/src/s_tir/schedule/primitive/cache_index.cc +++ b/src/s_tir/schedule/primitive/cache_index.cc @@ -289,7 +289,7 @@ ffi::Array MakeIndexCacheStage(IndexInfo* info, const ffi::String& stora // block variables ffi::Array block_vars; // block access region for write buffers - Region access_region; + ffi::Array access_region; // indices used in block body ffi::Array access_indices; ffi::Map block_var_map; diff --git a/src/s_tir/schedule/primitive/cache_read_write.cc b/src/s_tir/schedule/primitive/cache_read_write.cc index 39d3bacfbe8b..d5015b75a318 100644 --- a/src/s_tir/schedule/primitive/cache_read_write.cc +++ b/src/s_tir/schedule/primitive/cache_read_write.cc @@ -182,17 +182,17 @@ SBlock MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStage } // block access region for read/write buffers - Region read_access_region, write_access_region; + ffi::Array read_access_region, write_access_region; ffi::Array read_access_indices, write_access_indices; // Compute read/write region and read/write access indices. ffi::Array& old_indices = (is_cache_read) ? read_access_indices : write_access_indices; - Region& old_region = (is_cache_read) ? read_access_region : write_access_region; + ffi::Array& old_region = (is_cache_read) ? read_access_region : write_access_region; for (const Range& range : cache_region->region) { old_indices.push_back(Substitute(range->min, var_map)); old_region.push_back(Range::FromMinExtent(old_indices.back(), Integer(1))); } ffi::Array& new_indices = (is_cache_read) ? write_access_indices : read_access_indices; - Region& new_region = (is_cache_read) ? write_access_region : read_access_region; + ffi::Array& new_region = (is_cache_read) ? write_access_region : read_access_region; for (const PrimExpr& idx : info->indices) { new_indices.push_back(Substitute((idx), var_map)); new_region.push_back(Range::FromMinExtent(new_indices.back(), Integer(1))); @@ -254,8 +254,8 @@ SBlock MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, // block variables ffi::Array block_vars; // block access region for read/write buffers - Region read_access_region; - Region write_access_region; + ffi::Array read_access_region; + ffi::Array write_access_region; // indices used in block body ffi::Array read_access_indices; ffi::Array write_access_indices; @@ -384,8 +384,8 @@ SBlock MakeReIndexStage(const SBlock& block, CacheStageInfo* info, // Step 3: Create the reindex block // The src and the dst region and indices of the data copy - Region src_region{nullptr}; - Region dst_region{nullptr}; + ffi::Array src_region{nullptr}; + ffi::Array dst_region{nullptr}; ffi::Array src_indices{nullptr}; ffi::Array dst_indices{nullptr}; @@ -635,7 +635,7 @@ BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_re /*analyzer=*/&analyzer); TVM_FFI_ICHECK_EQ(buffer_region->region.size(), int_sets.size()); - Region region; + ffi::Array region; region.reserve(int_sets.size()); for (size_t i = 0; i < int_sets.size(); ++i) { region.push_back(int_sets[i].CoverRange(Range::FromMinExtent(0, buffer->shape[i]))); @@ -901,7 +901,7 @@ class CacheReadRewriter : public StmtExprMutator { explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info, bool cache_full_region = true) : scope_sref_(scope_sref), info_(info), cache_full_region_(cache_full_region) { - auto update_region = [this](const Region& region, const Region& offset) -> Region { + auto update_region = [this](const ffi::Array& region, const ffi::Array& offset) -> ffi::Array { TVM_FFI_ICHECK_EQ(region.size(), offset.size()); std::vector ret; for (size_t i = 0; i < region.size(); ++i) { @@ -1158,7 +1158,7 @@ class CacheWriteRewriter : public StmtExprMutator { writer_block_sref_(writer_block_sref), info_(info), cache_full_region_(cache_full_region) { - auto update_region = [this](const Region& region, const Region& offset) -> Region { + auto update_region = [this](const ffi::Array& region, const ffi::Array& offset) -> ffi::Array { TVM_FFI_ICHECK_EQ(region.size(), offset.size()); std::vector ret; for (size_t i = 0; i < region.size(); ++i) { @@ -1680,7 +1680,7 @@ class ReIndexRewriter : public StmtExprMutator { /*! \brief The new indices */ ffi::Array indices_; /*! \brief The new region */ - Region region_; + ffi::Array region_; }; void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root, Buffer read_buffer) { diff --git a/src/s_tir/schedule/primitive/decompose_padding.cc b/src/s_tir/schedule/primitive/decompose_padding.cc index ee2045b7eef6..c7a6ce1ceeb2 100644 --- a/src/s_tir/schedule/primitive/decompose_padding.cc +++ b/src/s_tir/schedule/primitive/decompose_padding.cc @@ -313,7 +313,7 @@ static std::pair CreateInBoundBlock(const SBlockRealizeNode auto rewrite_expr = [&repl_dict, analyzer](const PrimExpr& e) { return analyzer->Simplify(Substitute(e, repl_dict)); }; - auto rewrite_region = [rewrite_expr](const Region& region) { + auto rewrite_region = [rewrite_expr](const ffi::Array& region) { return region.Map([rewrite_expr](const Range& r) { return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent)); }); diff --git a/src/s_tir/schedule/primitive/rolling_buffer.cc b/src/s_tir/schedule/primitive/rolling_buffer.cc index 85e4d3b2a8bb..5c2b1a985da3 100644 --- a/src/s_tir/schedule/primitive/rolling_buffer.cc +++ b/src/s_tir/schedule/primitive/rolling_buffer.cc @@ -44,7 +44,7 @@ BufferRegion GetRelaxedBufferRegion(const SBlockRealize& realize, const BufferRe const ffi::Map& dom_map) { ffi::Array relaxed_intsets = arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); - Region relaxed_region; + ffi::Array relaxed_region; relaxed_region.reserve(relaxed_intsets.size()); for (size_t i = 0; i < relaxed_intsets.size(); ++i) { relaxed_region.push_back( @@ -165,7 +165,7 @@ class RollingBufferInfoCollector { private: bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { const Buffer& buffer = buffer_region->buffer; - const Region& region = buffer_region->region; + const ffi::Array& region = buffer_region->region; std::vector> bound_iter_vars; std::vector bound_overlaps; diff --git a/src/s_tir/support/nd_int_set.h b/src/s_tir/support/nd_int_set.h index 03f3672b452d..df9aa8e3dc64 100644 --- a/src/s_tir/support/nd_int_set.h +++ b/src/s_tir/support/nd_int_set.h @@ -36,7 +36,7 @@ using NDIntSet = std::vector; * \param region The region. * \return The constructed set. */ -inline NDIntSet NDIntSetFromRegion(const tirx::Region& region) { +inline NDIntSet NDIntSetFromRegion(const ffi::Array& region) { NDIntSet result; result.reserve(region.size()); for (const Range& range : region) { diff --git a/src/s_tir/transform/compact_buffer_region.cc b/src/s_tir/transform/compact_buffer_region.cc index d01f24c670a4..c4e68d24bd89 100644 --- a/src/s_tir/transform/compact_buffer_region.cc +++ b/src/s_tir/transform/compact_buffer_region.cc @@ -47,7 +47,7 @@ using namespace tvm::tirx; using support::NDIntSet; /*! \brief a more constrained bound estimate for n-dimentional int set */ -NDIntSet NDIntSetEval(Region region, PrimExpr predicate, +NDIntSet NDIntSetEval(ffi::Array region, PrimExpr predicate, const std::unordered_map& dom_map, arith::Analyzer* analyzer) { std::unordered_map var_dom; @@ -111,7 +111,7 @@ class Var2BufferCollector : public StmtExprVisitor { */ class BufferAccessRegionCollector : public StmtExprVisitor { public: - static std::unordered_map Collect( + static std::unordered_map, ffi::ObjectPtrHash, ffi::ObjectPtrEqual> Collect( const PrimFunc& f, bool collect_inbound) { BufferAccessRegionCollector region_collector(collect_inbound); @@ -528,7 +528,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { * The entire access region should get updated on the buffer's define point * and we sanity check that every buffer is defined only once. */ - std::unordered_map buffer_access_region_; + std::unordered_map, ffi::ObjectPtrHash, ffi::ObjectPtrEqual> buffer_access_region_; /*! \brief The map from Buffer to it's access regions annotated by current block. */ std::unordered_map, ffi::ObjectPtrHash, ffi::ObjectPtrEqual> @@ -548,7 +548,7 @@ struct DimAlignInfo { struct BufferAllocInfo { /*! \brief The buffer access region. */ - Region region; + ffi::Array region; /*! \brief The storage alignment information. */ std::vector dim_aligns; /*! @@ -644,7 +644,7 @@ class BufferCompactor : public StmtExprMutator { *indices = std::move(new_indices); } - void RewriteBufferRegion(Buffer* buffer, Region* region) const { + void RewriteBufferRegion(Buffer* buffer, ffi::Array* region) const { auto it = buffer_info_.find((*buffer)->data); if (it == buffer_info_.end()) { // Skip if the buffer is parameter @@ -652,7 +652,7 @@ class BufferCompactor : public StmtExprMutator { } const BufferAllocInfo& info = it->second; TVM_FFI_ICHECK_EQ(region->size(), info.region.size()); - Region new_region; + ffi::Array new_region; new_region.reserve(info.region.size()); for (size_t i = 0; i < info.region.size(); ++i) { const Range& range = (*region)[i]; @@ -716,14 +716,14 @@ ffi::Array CalcStrides(const BufferAllocInfo& alloc_info, Stmt BufferCompactorCompact( const PrimFunc& f, - const std::unordered_map& regions, + const std::unordered_map, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>& regions, const std::unordered_map& storage_align) { // collect buffer allocation info for no-alias buffers std::unordered_map buffer_info; for (const auto& kv : regions) { const Buffer& buffer = kv.first; // set dim alignment info - Region region = kv.second; + ffi::Array region = kv.second; BufferAllocInfo alloc_info; auto it = storage_align.find(buffer->data); if (it != storage_align.end()) { diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index 151264405207..14997709b8b5 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -237,7 +237,7 @@ class PipelineBodyRewriter : public StmtExprMutator { BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const { auto it = buffer_remap_.find(buffer_region->buffer); if (it != buffer_remap_.end()) { - Region new_region = buffer_region->region; + ffi::Array new_region = buffer_region->region; const Buffer& new_buffer = (*it).second; // For pipeline buffers, relax the access region of the first dimension to full extent // if access_all_versions == true @@ -444,7 +444,7 @@ class PipelineRewriter : public StmtExprMutator { * \param region2 The second region. * \return Whether region1 and region2 have intersections. */ - bool MayConflict(Region region1, Region region2) { + bool MayConflict(ffi::Array region1, ffi::Array region2) { TVM_FFI_ICHECK(region1.size() == region2.size()); for (size_t i = 0; i < region1.size(); i++) { Range dim1 = region1[i]; @@ -1203,7 +1203,7 @@ class PipelineInjector : private StmtExprMutator { void AddAllocBuffers(SBlockNode* n, const ffi::Array alloc_buffers) { for (const Buffer& alloc_buffer : alloc_buffers) { n->alloc_buffers.push_back(alloc_buffer); - Region region; + ffi::Array region; region.reserve(alloc_buffer->shape.size()); for (const PrimExpr& dim : alloc_buffer->shape) { region.push_back(Range::FromMinExtent(0, dim)); diff --git a/src/s_tir/transform/lower_match_buffer.cc b/src/s_tir/transform/lower_match_buffer.cc index bd27d5189321..ac23bb87d537 100644 --- a/src/s_tir/transform/lower_match_buffer.cc +++ b/src/s_tir/transform/lower_match_buffer.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "../../tirx/ir/functor_common.h" #include "../../tirx/transform/ir_utils.h" @@ -153,7 +154,7 @@ class MatchBufferLower : public StmtExprMutator { return buffer_region; } else { const BufferRegion& source = (*it).second; - Region region = ConvertRegion(MatchBufferRegion(buffer, source), buffer_region->region); + ffi::Array region = ConvertRegion(MatchBufferRegion(buffer, source), buffer_region->region); return BufferRegion(source->buffer, std::move(region)); } } diff --git a/src/s_tir/transform/memhammer_lower_auto_copy.cc b/src/s_tir/transform/memhammer_lower_auto_copy.cc index 3836256449f9..b3536be6619a 100644 --- a/src/s_tir/transform/memhammer_lower_auto_copy.cc +++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc @@ -573,7 +573,7 @@ class AutoPadder { Buffer src_buffer = r->source->buffer; runtime::StorageScope scope = runtime::StorageScope::Create(src_buffer.scope()); if (scope.rank == runtime::StorageRank::kShared) { - Region region = r->source->region; + ffi::Array region = r->source->region; ffi::Array indices; for (int i = 0; i < static_cast(region.size()); i++) { Var var("region" + std::to_string(i)); diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 8ff66df53bde..9bc92a9e1095 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include namespace tvm { @@ -83,7 +83,7 @@ StmtBlockDoc::StmtBlockDoc(ffi::Array stmts) { this->data_ = std::move(n); } -LiteralDoc::LiteralDoc(ffi::Any value, const ffi::Optional& object_path) { +LiteralDoc::LiteralDoc(ffi::Any value, const ffi::Optional& object_path) { ffi::ObjectPtr n = ffi::make_object(); n->value = value; if (object_path.defined()) { @@ -273,7 +273,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "script.printer.DocSetSourcePaths", - [](Doc doc, ffi::Array source_paths) { doc->source_paths = source_paths; }); + [](Doc doc, ffi::Array source_paths) { doc->source_paths = source_paths; }); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc index ad81297f97be..30754a66ee51 100644 --- a/src/script/printer/doc_printer/base_doc_printer.cc +++ b/src/script/printer/doc_printer/base_doc_printer.cc @@ -264,7 +264,7 @@ DocPrinter::DocPrinter(const PrinterConfig& options) : options_(options) { void DocPrinter::Append(const Doc& doc) { Append(doc, PrinterConfig()); } void DocPrinter::Append(const Doc& doc, const PrinterConfig& cfg) { - for (const AccessPath& p : cfg->path_to_underline) { + for (const ffi::reflection::AccessPath& p : cfg->path_to_underline) { path_to_underline_.push_back(p); current_max_path_depth_.push_back(0); current_underline_candidates_.push_back(std::vector()); @@ -348,15 +348,15 @@ void DocPrinter::PrintDoc(const Doc& doc) { } size_t end_pos = output_.tellp(); - for (const AccessPath& path : doc->source_paths) { + for (const ffi::reflection::AccessPath& path : doc->source_paths) { MarkSpan({start_pos, end_pos}, path); } } -void DocPrinter::MarkSpan(const ByteSpan& span, const AccessPath& path) { +void DocPrinter::MarkSpan(const ByteSpan& span, const ffi::reflection::AccessPath& path) { int n = path_to_underline_.size(); for (int i = 0; i < n; ++i) { - AccessPath p = path_to_underline_[i]; + ffi::reflection::AccessPath p = path_to_underline_[i]; if (path->depth >= current_max_path_depth_[i] && path->IsPrefixOf(p)) { if (path->depth > current_max_path_depth_[i]) { current_max_path_depth_[i] = path->depth; diff --git a/src/script/printer/doc_printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h index 6708ce156b20..cbad586d558e 100644 --- a/src/script/printer/doc_printer/base_doc_printer.h +++ b/src/script/printer/doc_printer/base_doc_printer.h @@ -255,7 +255,7 @@ class DocPrinter { std::vector underlines_exempted_; private: - void MarkSpan(const ByteSpan& span, const AccessPath& path); + void MarkSpan(const ByteSpan& span, const ffi::reflection::AccessPath& path); /*! \brief Options to customize certain aspects of the output */ PrinterConfig options_; @@ -267,7 +267,7 @@ class DocPrinter { std::vector line_starts_; /*! \brief Path of the object that we would like to underline */ - ffi::Array path_to_underline_; + ffi::Array path_to_underline_; /*! * \brief Candidate spans to be underlined, until we find a better match. diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index db04e7427acd..78b9b9fa986f 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -18,7 +18,7 @@ */ #include #include -#include +#include #include #include diff --git a/src/script/printer/ir/distributed.cc b/src/script/printer/ir/distributed.cc index 5abc316154e0..60c0e3ceaf7e 100644 --- a/src/script/printer/ir/distributed.cc +++ b/src/script/printer/ir/distributed.cc @@ -24,7 +24,7 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](ffi::Shape n, AccessPath n_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](ffi::Shape n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { int s = n.size(); ffi::Array results; results.reserve(s); diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index a9b998d03eb2..4029863aeeaa 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -56,7 +56,7 @@ struct SortableFunction { }; TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](IRModule mod, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](IRModule mod, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { std::vector functions; for (const auto& kv : mod->functions) { functions.push_back(SortableFunction(kv)); @@ -113,22 +113,22 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](DictAttrs attrs, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](DictAttrs attrs, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return d->AsDoc(attrs->dict, p->Attr("dict")); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](GlobalVar gv, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](GlobalVar gv, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return IR(d, "GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](GlobalInfo ginfo, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](GlobalInfo ginfo, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return IR(d, "dummy_global_info")->Call({}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](VDevice vdev, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](VDevice vdev, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { d->AddGlobalInfo("vdevice", vdev); ffi::Map config = vdev->target->ToConfig(); return IR(d, "vdevice") @@ -138,12 +138,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](Op op, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](Op op, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](FuncType func_type, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](FuncType func_type, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return IR(d, "FuncType") ->Call({ d->AsDoc(func_type->arg_types, p->Attr("arg_types")), @@ -152,7 +152,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("ir", [](Range range, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("ir", [](Range range, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return IR(d, "Range") ->Call({ d->AsDoc(range->min, p->Attr("min")), diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc index f33170577154..64eb69bf5668 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -24,7 +24,7 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch>( // - "", [](ffi::Array array, AccessPath p, IRDocsifier d) -> Doc { + "", [](ffi::Array array, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { int n = array.size(); ffi::Array results; results.reserve(n); @@ -36,7 +36,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch>( // - "", [](ffi::Map dict, AccessPath p, IRDocsifier d) -> Doc { + "", [](ffi::Map dict, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { using POO = std::pair; std::vector items{dict.begin(), dict.end()}; bool is_str_map = true; diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 5fb247a4882a..76631c169c24 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include @@ -209,7 +209,7 @@ IRDocsifier::FType& IRDocsifier::vtable() { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_fallback([](ffi::ObjectRef obj, AccessPath p, IRDocsifier d) -> Doc { + .set_fallback([](ffi::ObjectRef obj, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return d->AddMetadata(obj); }); diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 84d1854b756d..41bed45a0552 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -59,7 +60,7 @@ inline std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj) { inline std::string Docsify(const ffi::ObjectRef& obj, const IRDocsifier& d, const Frame& f, const PrinterConfig& cfg) { - Doc doc = d->AsDoc(obj, AccessPath::Root()); + Doc doc = d->AsDoc(obj, ffi::reflection::AccessPath::Root()); bool move_source_paths = false; if (const auto* expr_doc = doc.as()) { if (!cfg->verbose_expr) { @@ -122,7 +123,7 @@ inline ExprDoc Relax(const IRDocsifier& d, const ffi::String& attr) { } inline std::string DType2Str(const runtime::DataType& dtype) { - return dtype.is_void() ? "void" : runtime::DLDataTypeToString(dtype); + return dtype.is_void() ? "void" : ffi::DLDataTypeToString(dtype); } /*! \brief Add headers as comments to doc if needed */ diff --git a/src/target/codegen.cc b/src/target/codegen.cc index f24cb6a49497..39500a0451fa 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -35,6 +35,7 @@ #include #include +#include #include #include #include @@ -109,7 +110,7 @@ class ModuleSerializer { uint64_t module_index = 0; auto fpush_imports_to_stack = [&](ffi::ModuleObj* node) { - for (Any m : node->imports()) { + for (ffi::Any m : node->imports()) { ffi::ModuleObj* next = m.cast().operator->(); if (visited.count(next) == 0) { visited.insert(next); @@ -177,7 +178,7 @@ class ModuleSerializer { for (size_t parent_index = 0; parent_index < mod_group_vec_.size(); ++parent_index) { child_indices.clear(); for (const auto* m : mod_group_vec_[parent_index]) { - for (Any im : m->imports()) { + for (ffi::Any im : m->imports()) { uint64_t mod_index = mod2index_.at(im.cast().operator->()); // skip cycle when dso modules are merged together if (mod_index != parent_index) { diff --git a/src/target/cuda/intrin_rule_cuda.cc b/src/target/cuda/intrin_rule_cuda.cc index d38db9fe8372..fc2d78a30710 100644 --- a/src/target/cuda/intrin_rule_cuda.cc +++ b/src/target/cuda/intrin_rule_cuda.cc @@ -247,7 +247,7 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_sync") .add_argument("var", "Expr", "The variable to sync.") .add_argument("lane", "Expr", "The source thread id.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") - .set_attr("TGlobalSymbol", "__shfl_sync") + .set_attr("TGlobalSymbol", "__shfl_sync") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); @@ -257,7 +257,7 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_up_sync") .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be added.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") - .set_attr("TGlobalSymbol", "__shfl_up_sync") + .set_attr("TGlobalSymbol", "__shfl_up_sync") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); @@ -267,13 +267,13 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_down_sync") .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") - .set_attr("TGlobalSymbol", "__shfl_down_sync") + .set_attr("TGlobalSymbol", "__shfl_down_sync") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tirx.cuda.__activemask") .set_num_inputs(0) - .set_attr("TGlobalSymbol", "__activemask") + .set_attr("TGlobalSymbol", "__activemask") .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("cuda.need_warp_shuffle", true); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 06c521a85502..a0e237500c19 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -81,6 +81,7 @@ #include #include #include +#include #include #include @@ -2265,7 +2266,7 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) if (dtype.is_scalable_vector()) return nullptr; - return dbg_info_->di_builder_->createBasicType(DLDataTypeToString(dtype).operator std::string(), + return dbg_info_->di_builder_->createBasicType(ffi::DLDataTypeToString(dtype).operator std::string(), dtype.bits() * dtype.lanes(), dwarf_type); } else { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index b57a1a446bcf..e6c3176867d5 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -563,7 +563,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::DISubprogram* di_subprogram_{nullptr}; // Cache potential common path ops to slightly improve lookup time. // global symbol table. - OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); + OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); const Op& builtin_call_extern_ = builtin::call_extern(); const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); diff --git a/src/target/metal/intrin_rule_metal.cc b/src/target/metal/intrin_rule_metal.cc index cea19519ca7f..94f4c0fbe308 100644 --- a/src/target/metal/intrin_rule_metal.cc +++ b/src/target/metal/intrin_rule_metal.cc @@ -143,21 +143,21 @@ TVM_REGISTER_OP("tirx.metal.simd_shuffle") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("lane", "Expr", "The source thread id.") - .set_attr("TGlobalSymbol", "simd_shuffle") + .set_attr("TGlobalSymbol", "simd_shuffle") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.metal.simd_shuffle_up") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be added.") - .set_attr("TGlobalSymbol", "simd_shuffle_up") + .set_attr("TGlobalSymbol", "simd_shuffle_up") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.metal.simd_shuffle_down") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") - .set_attr("TGlobalSymbol", "simd_shuffle_down") + .set_attr("TGlobalSymbol", "simd_shuffle_down") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace intrin diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 29c5e420997e..0914abc79dff 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -316,7 +316,7 @@ class CodeGenC : public ExprFunctor, /*! \brief the data type of allocated buffers */ std::unordered_map handle_data_type_; /*! \brief Record of ops that have pre-defined global symbol. */ - OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); + OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); // cache commonly used ops const Op& builtin_call_extern_ = builtin::call_extern(); const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); diff --git a/src/target/virtual_device.cc b/src/target/virtual_device.cc index c7357d7f14f7..6c83acbe4e2f 100644 --- a/src/target/virtual_device.cc +++ b/src/target/virtual_device.cc @@ -68,7 +68,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } VirtualDevice::VirtualDevice(int device_type_int, int virtual_device_id, Target target, - MemoryScope memory_scope) { + ffi::String memory_scope) { TVM_FFI_ICHECK(!target.defined() || device_type_int == target->GetTargetDeviceType()) << "target " << target->str() << " has device type " << target->GetTargetDeviceType() << " but virtual device has device type " << device_type_int; @@ -118,7 +118,7 @@ ffi::Optional VirtualDevice::Join(const VirtualDevice& lhs, } else { joined_target = rhs->target; } - MemoryScope joined_memory_scope; + ffi::String joined_memory_scope; if (!lhs->memory_scope.empty()) { joined_memory_scope = lhs->memory_scope; if (!rhs->memory_scope.empty() && lhs->memory_scope != rhs->memory_scope) { @@ -158,7 +158,7 @@ VirtualDevice VirtualDevice::Default(const VirtualDevice& lhs, const VirtualDevi } // else: leave as null } - MemoryScope defaulted_memory_scope; + ffi::String defaulted_memory_scope; if (!lhs->memory_scope.empty()) { defaulted_memory_scope = lhs->memory_scope; } else { @@ -169,7 +169,7 @@ VirtualDevice VirtualDevice::Default(const VirtualDevice& lhs, const VirtualDevi } VirtualDevice VirtualDeviceCache::Make(int device_type, int virtual_device_id, Target target, - MemoryScope memory_scope) { + ffi::String memory_scope) { VirtualDevice prototype(device_type, virtual_device_id, std::move(target), std::move(memory_scope)); if (prototype->IsFullyUnconstrained()) { diff --git a/src/target/webgpu/intrin_rule_webgpu.cc b/src/target/webgpu/intrin_rule_webgpu.cc index bc48395468d3..ce958a466d6c 100644 --- a/src/target/webgpu/intrin_rule_webgpu.cc +++ b/src/target/webgpu/intrin_rule_webgpu.cc @@ -163,21 +163,21 @@ TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("lane", "Expr", "The source thread id.") - .set_attr("TGlobalSymbol", "subgroupShuffle") + .set_attr("TGlobalSymbol", "subgroupShuffle") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_up") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be added.") - .set_attr("TGlobalSymbol", "subgroupShuffleUp") + .set_attr("TGlobalSymbol", "subgroupShuffleUp") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_down") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") - .set_attr("TGlobalSymbol", "subgroupShuffleDown") + .set_attr("TGlobalSymbol", "subgroupShuffleDown") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace intrin diff --git a/src/tirx/analysis/verify_memory.cc b/src/tirx/analysis/verify_memory.cc index d97be6c5daae..6c4ba1193400 100644 --- a/src/tirx/analysis/verify_memory.cc +++ b/src/tirx/analysis/verify_memory.cc @@ -29,6 +29,7 @@ #include #include #include +#include namespace tvm { namespace tirx { diff --git a/src/tirx/analysis/verify_well_formed.cc b/src/tirx/analysis/verify_well_formed.cc index 6f9858dd074e..cc33e59d5690 100644 --- a/src/tirx/analysis/verify_well_formed.cc +++ b/src/tirx/analysis/verify_well_formed.cc @@ -40,7 +40,6 @@ namespace tvm { namespace tirx { -using AccessPath = ffi::reflection::AccessPath; namespace { @@ -234,19 +233,19 @@ class UndefinedVarVerifier : public Verifier { private: using Verifier::Visit; - void Visit(const PrimFunc& prim_func, AccessPath path) override { + void Visit(const PrimFunc& prim_func, ffi::reflection::AccessPath path) override { Verifier::Visit(prim_func, path); redefine_allowed_within_function_.clear(); } - void EnterDef(const IterVar& iter_var, AccessPath path) override { + void EnterDef(const IterVar& iter_var, ffi::reflection::AccessPath path) override { Verifier::EnterDef(iter_var, path); if (iter_var->iter_type == IterVarType::kThreadIndex) { redefine_allowed_within_function_.insert(iter_var->var); } } - void EnterDef(const Var& var, AccessPath path) override { + void EnterDef(const Var& var, ffi::reflection::AccessPath path) override { bool redefine_is_allowed = redefine_allowed_within_function_.count(var); { auto it = currently_defined_.find(var); @@ -272,14 +271,14 @@ class UndefinedVarVerifier : public Verifier { currently_defined_.insert({var, path}); } - void ExitDef(const Var& var, AccessPath path) override { + void ExitDef(const Var& var, ffi::reflection::AccessPath path) override { auto active_def = currently_defined_.find(var); currently_defined_.erase(active_def); previously_defined_.insert({var, path}); } - void VisitExpr_(const VarNode* op, AccessPath path) override { + void VisitExpr_(const VarNode* op, ffi::reflection::AccessPath path) override { auto var = ffi::GetRef(op); auto active_def = currently_defined_.find(var); @@ -298,10 +297,10 @@ class UndefinedVarVerifier : public Verifier { } // Variables that are defined in the currently-visited scope. - std::unordered_map currently_defined_; + std::unordered_map currently_defined_; // Variables that were previously defined, and are now out of scope. - std::unordered_map previously_defined_; + std::unordered_map previously_defined_; // Special variables that are allowed to be re-defined, so long as // that re-definition occurs within the same PrimFunc. For example @@ -328,20 +327,20 @@ class UndefinedBufferVerifier : public Verifier { private: using Verifier::Visit; - void Visit(const PrimFunc& prim_func, AccessPath path) override { + void Visit(const PrimFunc& prim_func, ffi::reflection::AccessPath path) override { Verifier::Visit(prim_func, path); // Clear per-function state (buffers should not cross function boundaries). currently_defined_.clear(); previously_defined_.clear(); } - void EnterDef(const Buffer& buffer, AccessPath path) override { + void EnterDef(const Buffer& buffer, ffi::reflection::AccessPath path) override { // Call the base class to visit buffer's internal vars (shape, strides, etc.) Verifier::EnterDef(buffer, path); currently_defined_.insert({buffer, path}); } - void ExitDef(const Buffer& buffer, AccessPath path) override { + void ExitDef(const Buffer& buffer, ffi::reflection::AccessPath path) override { auto active_def = currently_defined_.find(buffer); if (active_def != currently_defined_.end()) { currently_defined_.erase(active_def); @@ -349,7 +348,7 @@ class UndefinedBufferVerifier : public Verifier { previously_defined_.insert({buffer, path}); } - void VisitBufferUse(const Buffer& buffer, AccessPath path) override { + void VisitBufferUse(const Buffer& buffer, ffi::reflection::AccessPath path) override { bool is_declared = currently_defined_.count(buffer); bool was_declared = previously_defined_.count(buffer); @@ -369,10 +368,10 @@ class UndefinedBufferVerifier : public Verifier { } // Buffers defined in the currently-visited scope. - std::unordered_map + std::unordered_map currently_defined_; // Buffers that were previously defined and are now out of scope. - std::unordered_map + std::unordered_map previously_defined_; }; @@ -389,12 +388,12 @@ class SingleEnvThreadVerifier : public Verifier { using Verifier::Verifier; private: - void Visit(const PrimFunc& prim_func, AccessPath path) override { + void Visit(const PrimFunc& prim_func, ffi::reflection::AccessPath path) override { Verifier::Visit(prim_func, path); env_thread_vars_.clear(); } - void EnterDef(const IterVar& iter_var, AccessPath path) override { + void EnterDef(const IterVar& iter_var, ffi::reflection::AccessPath path) override { if (iter_var->iter_type == IterVarType::kThreadIndex) { if (auto it = env_thread_vars_.find(iter_var->thread_tag); it != env_thread_vars_.end()) { const auto& [prev_var, prev_path] = it->second; @@ -413,7 +412,7 @@ class SingleEnvThreadVerifier : public Verifier { } } - std::unordered_map> env_thread_vars_; + std::unordered_map> env_thread_vars_; }; bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { diff --git a/src/tirx/ir/stmt.cc b/src/tirx/ir/stmt.cc index 983cdaa9c602..4e2cfd3f1474 100644 --- a/src/tirx/ir/stmt.cc +++ b/src/tirx/ir/stmt.cc @@ -510,7 +510,7 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { // Validate shape TVM_FFI_ICHECK(source->region.size() >= buffer->shape.size()) - << "Dimension of source Region expected to be larger or equal than target buffer shape, but " + << "Dimension of source ffi::Array expected to be larger or equal than target buffer shape, but " "got " << source->region.size() << " vs. " << buffer->shape.size(); size_t offset = source->region.size() - buffer->shape.size(); diff --git a/src/tirx/ir/tir_visitor_with_path.cc b/src/tirx/ir/tir_visitor_with_path.cc index 857ccca08eff..512225344959 100644 --- a/src/tirx/ir/tir_visitor_with_path.cc +++ b/src/tirx/ir/tir_visitor_with_path.cc @@ -35,9 +35,8 @@ namespace tvm { namespace tirx { -using AccessPath = ffi::reflection::AccessPath; -void TIRVisitorWithPath::Visit(const IRModule& mod, AccessPath path) { +void TIRVisitorWithPath::Visit(const IRModule& mod, ffi::reflection::AccessPath path) { // To ensure deterministic order of visits, sort the GlobalVar first // by visibility (public then private), then alphabetically by name. std::vector gvars; @@ -76,7 +75,7 @@ void TIRVisitorWithPath::Visit(const IRModule& mod, AccessPath path) { while (context.size()) context.pop_back(); } -void TIRVisitorWithPath::Visit(const PrimFunc& func, AccessPath path) { +void TIRVisitorWithPath::Visit(const PrimFunc& func, ffi::reflection::AccessPath path) { // The implicit definitions from a PrimFunc::buffer_map are pretty // weird. They only apply if no previous definition of that // variable has occurred. Therefore, to ensure that we only avoid @@ -115,25 +114,25 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, AccessPath path) { while (context.size()) context.pop_back(); } -void TIRVisitorWithPath::EnterDef(const IterVar& iter_var, AccessPath path) { +void TIRVisitorWithPath::EnterDef(const IterVar& iter_var, ffi::reflection::AccessPath path) { if (iter_var->dom.defined()) { Visit(iter_var->dom, path->Attr("dom")); } EnterDef(iter_var->var, path->Attr("var")); } -void TIRVisitorWithPath::ExitDef(const IterVar& iter_var, AccessPath path) { +void TIRVisitorWithPath::ExitDef(const IterVar& iter_var, ffi::reflection::AccessPath path) { ExitDef(iter_var->var, path->Attr("var")); } -void TIRVisitorWithPath::EnterDef(const Buffer& buffer, AccessPath path) { +void TIRVisitorWithPath::EnterDef(const Buffer& buffer, ffi::reflection::AccessPath path) { // Defining a buffer counts as using all parameters in the buffer // (e.g. shape/strides). VisitBufferDef(buffer, path); } -void TIRVisitorWithPath::ExitDef(const Buffer& buffer, AccessPath path) {} +void TIRVisitorWithPath::ExitDef(const Buffer& buffer, ffi::reflection::AccessPath path) {} -void TIRVisitorWithPath::VisitBufferDef(const Buffer& buffer, AccessPath path) { +void TIRVisitorWithPath::VisitBufferDef(const Buffer& buffer, ffi::reflection::AccessPath path) { Visit(buffer->data, path->Attr("data")); Visit(buffer->shape, path->Attr("shape")); Visit(buffer->strides, path->Attr("strides")); @@ -145,14 +144,14 @@ void TIRVisitorWithPath::VisitBufferDef(const Buffer& buffer, AccessPath path) { // VisitBufferDef/EnterDef. Re-visiting at use sites would require those // variables to be in scope at every use, which may not hold when buffers // are allocated in a different scope than where they are used. -void TIRVisitorWithPath::VisitBufferUse(const Buffer& buffer, AccessPath path) {} +void TIRVisitorWithPath::VisitBufferUse(const Buffer& buffer, ffi::reflection::AccessPath path) {} -void TIRVisitorWithPath::Visit(const BufferRegion& region, AccessPath path) { +void TIRVisitorWithPath::Visit(const BufferRegion& region, ffi::reflection::AccessPath path) { VisitBufferUse(region->buffer, path->Attr("buffer")); Visit(region->region, path->Attr("region")); } -void TIRVisitorWithPath::Visit(const MatchBufferRegion& match, AccessPath path) { +void TIRVisitorWithPath::Visit(const MatchBufferRegion& match, ffi::reflection::AccessPath path) { Visit(match->source, path->Attr("source")); // MatchBufferRegion define the match->buffer, but do not own the @@ -160,26 +159,26 @@ void TIRVisitorWithPath::Visit(const MatchBufferRegion& match, AccessPath path) // definitions are handled in the BlockNode visitor. } -void TIRVisitorWithPath::Visit(const IterVar& iter_var, AccessPath path) { +void TIRVisitorWithPath::Visit(const IterVar& iter_var, ffi::reflection::AccessPath path) { if (iter_var->dom.defined()) { Visit(iter_var->dom, path->Attr("dom")); } Visit(iter_var->var, path->Attr("var")); } -void TIRVisitorWithPath::Visit(const Range& range, AccessPath path) { +void TIRVisitorWithPath::Visit(const Range& range, ffi::reflection::AccessPath path) { Visit(range->min, path->Attr("min")); Visit(range->extent, path->Attr("extent")); } -void TIRVisitorWithPath::VisitStmt_(const BindNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const BindNode* op, ffi::reflection::AccessPath path) { Visit(op->value, path->Attr("value")); // Push the Bind's var definition into the current scope. // The def lives until the enclosing scope (body-carrying stmt) exits. bind_scope_.Current().push_back(WithDef(op->var, path->Attr("var"))); } -void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ffi::reflection::AccessPath path) { Visit(op->value, path->Attr("value")); std::vector, DefContext, DefContext>> context; @@ -200,19 +199,19 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { } } -void TIRVisitorWithPath::VisitStmt_(const ForNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const ForNode* op, ffi::reflection::AccessPath path) { Visit(op->min, path->Attr("min")); Visit(op->extent, path->Attr("extent")); auto context = WithDef(op->loop_var, path->Attr("loop_var")); bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); } -void TIRVisitorWithPath::VisitStmt_(const WhileNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const WhileNode* op, ffi::reflection::AccessPath path) { Visit(op->condition, path->Attr("condition")); bind_scope_.WithNewScope([&]() { Visit(op->body, path->Attr("body")); }); } -void TIRVisitorWithPath::VisitStmt_(const AllocBufferNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const AllocBufferNode* op, ffi::reflection::AccessPath path) { // AllocBuffer both allocates the data variable and declares the buffer. // Push definitions into the current scope so they are visible to subsequent siblings. auto buf_path = path->Attr("buffer"); @@ -220,41 +219,41 @@ void TIRVisitorWithPath::VisitStmt_(const AllocBufferNode* op, AccessPath path) bind_scope_.Current().push_back(WithDef(op->buffer, buf_path)); } -void TIRVisitorWithPath::VisitStmt_(const DeclBufferNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const DeclBufferNode* op, ffi::reflection::AccessPath path) { // Push buffer definition into the current scope so it is visible to subsequent siblings. bind_scope_.Current().push_back(WithDef(op->buffer, path->Attr("buffer"))); } -void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, ffi::reflection::AccessPath path) { Visit(op->value, path->Attr("value")); VisitBufferUse(op->buffer, path->Attr("buffer")); Visit(op->indices, path->Attr("indices")); } -void TIRVisitorWithPath::VisitStmt_(const IfThenElseNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const IfThenElseNode* op, ffi::reflection::AccessPath path) { Visit(op->condition, path->Attr("condition")); bind_scope_.WithNewScope([&]() { Visit(op->then_case, path->Attr("then_case")); }); bind_scope_.WithNewScope([&]() { Visit(op->else_case, path->Attr("else_case")); }); } -void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, ffi::reflection::AccessPath path) { Visit(op->condition, path->Attr("condition")); Visit(op->error_kind, path->Attr("error_kind")); Visit(op->message_parts, path->Attr("message_parts")); } -void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, ffi::reflection::AccessPath path) { auto seq_path = path->Attr("seq"); for (size_t i = 0; i < op->seq.size(); i++) { Visit(op->seq[i], seq_path->ArrayItem(i)); } } -void TIRVisitorWithPath::VisitStmt_(const EvaluateNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const EvaluateNode* op, ffi::reflection::AccessPath path) { Visit(op->value, path->Attr("value")); } -void TIRVisitorWithPath::VisitStmt_(const SBlockNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const SBlockNode* op, ffi::reflection::AccessPath path) { std::vector, DefContext, DefContext>> context; { @@ -300,34 +299,34 @@ void TIRVisitorWithPath::VisitStmt_(const SBlockNode* op, AccessPath path) { while (context.size()) context.pop_back(); } -void TIRVisitorWithPath::VisitStmt_(const SBlockRealizeNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const SBlockRealizeNode* op, ffi::reflection::AccessPath path) { Visit(op->iter_values, path->Attr("iter_values")); Visit(op->predicate, path->Attr("predicate")); Visit(op->block, path->Attr("block")); } -void TIRVisitorWithPath::VisitExpr_(const VarNode* op, AccessPath path) {} +void TIRVisitorWithPath::VisitExpr_(const VarNode* op, ffi::reflection::AccessPath path) {} -void TIRVisitorWithPath::VisitExpr_(const SizeVarNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const SizeVarNode* op, ffi::reflection::AccessPath path) { VisitExpr_(static_cast(op), path); } -void TIRVisitorWithPath::VisitExpr_(const BufferLoadNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const BufferLoadNode* op, ffi::reflection::AccessPath path) { VisitBufferUse(op->buffer, path->Attr("buffer")); Visit(op->indices, path->Attr("indices")); } -void TIRVisitorWithPath::VisitExpr_(const ProducerLoadNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const ProducerLoadNode* op, ffi::reflection::AccessPath path) { Visit(op->indices, path->Attr("indices")); } -void TIRVisitorWithPath::VisitExpr_(const LetNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const LetNode* op, ffi::reflection::AccessPath path) { Visit(op->value, path->Attr("value")); auto context = WithDef(op->var, path->Attr("var")); Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitExpr_(const CallNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const CallNode* op, ffi::reflection::AccessPath path) { if (auto gvar = op->op.as()) { Visit(gvar.value(), path->Attr("op")); } @@ -335,7 +334,7 @@ void TIRVisitorWithPath::VisitExpr_(const CallNode* op, AccessPath path) { } #define DEFINE_BINOP_VISIT_(OP) \ - void TIRVisitorWithPath::VisitExpr_(const OP* op, AccessPath path) { \ + void TIRVisitorWithPath::VisitExpr_(const OP* op, ffi::reflection::AccessPath path) { \ Visit(op->a, path->Attr("a")); \ Visit(op->b, path->Attr("b")); \ } @@ -360,43 +359,43 @@ DEFINE_BINOP_VISIT_(OrNode); #undef DEFINE_BINOP_VISIT_ -void TIRVisitorWithPath::VisitExpr_(const IntImmNode* op, AccessPath path) {} -void TIRVisitorWithPath::VisitExpr_(const FloatImmNode* op, AccessPath path) {} -void TIRVisitorWithPath::VisitExpr_(const StringImmNode* op, AccessPath path) {} +void TIRVisitorWithPath::VisitExpr_(const IntImmNode* op, ffi::reflection::AccessPath path) {} +void TIRVisitorWithPath::VisitExpr_(const FloatImmNode* op, ffi::reflection::AccessPath path) {} +void TIRVisitorWithPath::VisitExpr_(const StringImmNode* op, ffi::reflection::AccessPath path) {} -void TIRVisitorWithPath::VisitExpr_(const ReduceNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const ReduceNode* op, ffi::reflection::AccessPath path) { Visit(op->axis, path->Attr("axis")); Visit(op->source, path->Attr("source")); Visit(op->init, path->Attr("init")); Visit(op->condition, path->Attr("condition")); } -void TIRVisitorWithPath::VisitExpr_(const CastNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const CastNode* op, ffi::reflection::AccessPath path) { Visit(op->value, path->Attr("value")); } -void TIRVisitorWithPath::VisitExpr_(const NotNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const NotNode* op, ffi::reflection::AccessPath path) { Visit(op->a, path->Attr("a")); } -void TIRVisitorWithPath::VisitExpr_(const SelectNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const SelectNode* op, ffi::reflection::AccessPath path) { Visit(op->condition, path->Attr("condition")); Visit(op->true_value, path->Attr("true_value")); Visit(op->false_value, path->Attr("false_value")); } -void TIRVisitorWithPath::VisitExpr_(const RampNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const RampNode* op, ffi::reflection::AccessPath path) { Visit(op->base, path->Attr("base")); Visit(op->stride, path->Attr("stride")); Visit(op->lanes, path->Attr("lanes")); } -void TIRVisitorWithPath::VisitExpr_(const ShuffleNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const ShuffleNode* op, ffi::reflection::AccessPath path) { Visit(op->indices, path->Attr("indices")); Visit(op->vectors, path->Attr("vectors")); } -void TIRVisitorWithPath::VisitExpr_(const BroadcastNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitExpr_(const BroadcastNode* op, ffi::reflection::AccessPath path) { Visit(op->value, path->Attr("value")); Visit(op->lanes, path->Attr("lanes")); } diff --git a/src/tirx/op/builtin.cc b/src/tirx/op/builtin.cc index 4355583d796b..3ba16d9f9cf3 100644 --- a/src/tirx/op/builtin.cc +++ b/src/tirx/op/builtin.cc @@ -211,11 +211,11 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array) // When num_inputs are not set, the function is assumed to be variable length. TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", ffi::String("call_packed"), /*plevel=*/20); + .set_attr("TScriptPrinterName", ffi::String("call_packed"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", ffi::String("call_cpacked"), /*plevel=*/20); + .set_attr("TScriptPrinterName", ffi::String("call_cpacked"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -226,12 +226,12 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_invariant) TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", ffi::String("call_packed_lowered"), + .set_attr("TScriptPrinterName", ffi::String("call_packed_lowered"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked_lowered) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", ffi::String("call_cpacked_lowered"), + .set_attr("TScriptPrinterName", ffi::String("call_cpacked_lowered"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered) @@ -410,7 +410,7 @@ TIR_DEFINE_BUILTIN_FUNC(anylist_getitem) TIR_DEFINE_BUILTIN_FUNC(anylist_resetitem) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TGlobalSymbol", "TVMBackendAnyListResetItem"); + .set_attr("TGlobalSymbol", "TVMBackendAnyListResetItem"); TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc index f1c9c8a9b507..fa3953664bed 100644 --- a/src/tirx/op/op.cc +++ b/src/tirx/op/op.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include // Centralized header for constant folders. @@ -1145,12 +1146,12 @@ TVM_TIR_REGISTER_PURE_BINARY_OP("ldexp"); TVM_TIR_REGISTER_OP("TVMBackendAllocWorkspace") .set_num_inputs(5) - .set_attr("TGlobalSymbol", "TVMBackendAllocWorkspace") + .set_attr("TGlobalSymbol", "TVMBackendAllocWorkspace") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") .set_num_inputs(3) - .set_attr("TGlobalSymbol", "TVMBackendFreeWorkspace") + .set_attr("TGlobalSymbol", "TVMBackendFreeWorkspace") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); // expose basic functions to node namespace diff --git a/src/tirx/op/runtime.cc b/src/tirx/op/runtime.cc index e013b21d6676..148c2b9c132e 100644 --- a/src/tirx/op/runtime.cc +++ b/src/tirx/op/runtime.cc @@ -29,12 +29,12 @@ namespace tirx { TVM_REGISTER_OP("tirx.TVMBackendAnyListSetPackedArg") .set_num_inputs(5) - .set_attr("TGlobalSymbol", "TVMBackendAnyListSetPackedArg") + .set_attr("TGlobalSymbol", "TVMBackendAnyListSetPackedArg") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.TVMBackendAnyListMoveFromPackedReturn") .set_num_inputs(3) - .set_attr("TGlobalSymbol", "TVMBackendAnyListMoveFromPackedReturn") + .set_attr("TGlobalSymbol", "TVMBackendAnyListMoveFromPackedReturn") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace tirx diff --git a/src/tirx/script/builder/ir.cc b/src/tirx/script/builder/ir.cc index 7044cfe7e390..6dc316590c01 100644 --- a/src/tirx/script/builder/ir.cc +++ b/src/tirx/script/builder/ir.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include "./utils.h" diff --git a/src/tirx/script/printer/block.cc b/src/tirx/script/printer/block.cc index 6c86d68ff5f4..3f7c86c49ff6 100644 --- a/src/tirx/script/printer/block.cc +++ b/src/tirx/script/printer/block.cc @@ -22,14 +22,14 @@ namespace tvm { namespace script { namespace printer { -Doc PrintBlock(IRDocsifier d, tirx::SBlock block, AccessPath block_p, // +Doc PrintBlock(IRDocsifier d, tirx::SBlock block, ffi::reflection::AccessPath block_p, // ffi::Optional opt_realize, - ffi::Optional opt_realize_p) { + ffi::Optional opt_realize_p) { With frame(d, block); TVM_FFI_ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined()); const tirx::SBlockRealizeNode* realize = opt_realize.defined() ? opt_realize.value().get() : nullptr; - AccessPath realize_p = *opt_realize_p; + ffi::reflection::AccessPath realize_p = *opt_realize_p; // Step 1. Handle block var and block bindings // Step 1.1. Obtain all loop var defined along path std::unordered_map loop_vars; @@ -69,7 +69,7 @@ Doc PrintBlock(IRDocsifier d, tirx::SBlock block, AccessPath block_p, // auto print_single_iter_var = [&](int i) { tirx::IterVar iter_var = block->iter_vars[i]; - AccessPath iter_var_p = block_p->Attr("iter_var")->ArrayItem(i); + ffi::reflection::AccessPath iter_var_p = block_p->Attr("iter_var")->ArrayItem(i); ExprDoc rhs = TIR(d, "axis"); if (iter_var->iter_type == tirx::IterVarType::kDataPar) { rhs = rhs->Attr("spatial"); @@ -120,10 +120,10 @@ Doc PrintBlock(IRDocsifier d, tirx::SBlock block, AccessPath block_p, // lhs.reserve(m); loop_var_doc.reserve(m); std::string binding_type = ""; - ffi::Array binding_paths; + ffi::Array binding_paths; for (int i : remap_vars_indices) { tirx::IterVar iter_var = block->iter_vars[i]; - AccessPath iter_var_p = block_p->Attr("iter_vars")->ArrayItem(i); + ffi::reflection::AccessPath iter_var_p = block_p->Attr("iter_vars")->ArrayItem(i); lhs.push_back(DefineVar(iter_var->var, *frame, d)); loop_var_doc.push_back(d->AsDoc(realize->iter_values[i], realize_p->Attr("iter_values")->ArrayItem(i))); @@ -180,7 +180,7 @@ Doc PrintBlock(IRDocsifier d, tirx::SBlock block, AccessPath block_p, // // Step 5. Handle `alloc_buffer` for (int i = 0, n = block->alloc_buffers.size(); i < n; ++i) { tirx::Buffer buffer = block->alloc_buffers[i]; - AccessPath buffer_p = block_p->Attr("alloc_buffers")->ArrayItem(i); + ffi::reflection::AccessPath buffer_p = block_p->Attr("alloc_buffers")->ArrayItem(i); IdDoc lhs = DefineBuffer(buffer, *frame, d); ExprDoc rhs = BufferDecl(buffer, "sblock_alloc_buffer", {}, buffer_p, *frame, d, BufferVarDefinition::DataPointer); @@ -189,7 +189,7 @@ Doc PrintBlock(IRDocsifier d, tirx::SBlock block, AccessPath block_p, // // Step 6. Handle `match_buffer` for (int i = 0, n = block->match_buffers.size(); i < n; ++i) { tirx::MatchBufferRegion buffer_region = block->match_buffers[i]; - AccessPath buffer_region_p = block_p->Attr("match_buffers")->ArrayItem(i); + ffi::reflection::AccessPath buffer_region_p = block_p->Attr("match_buffers")->ArrayItem(i); StmtDoc doc = d->AsDoc(buffer_region, buffer_region_p); (*frame)->stmts.push_back(doc); } @@ -218,7 +218,7 @@ Doc PrintBlock(IRDocsifier d, tirx::SBlock block, AccessPath block_p, // TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tirx::SBlockRealize realize, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::SBlockRealize realize, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { Doc doc = PrintBlock(d, realize->block, p->Attr("block"), realize, p); // since we do not have d->AsDoc for realize->block, // we should add possible doc decoration manually. @@ -227,7 +227,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::SBlock block, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::SBlock block, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return PrintBlock(d, block, p, std::nullopt, std::nullopt); }); diff --git a/src/tirx/script/printer/buffer.cc b/src/tirx/script/printer/buffer.cc index eb34153557ed..53149cdd7041 100644 --- a/src/tirx/script/printer/buffer.cc +++ b/src/tirx/script/printer/buffer.cc @@ -24,7 +24,7 @@ namespace tvm { namespace script { namespace printer { -ffi::Map BufferAttrs(tirx::Buffer buffer, const AccessPath& buffer_p, +ffi::Map BufferAttrs(tirx::Buffer buffer, const ffi::reflection::AccessPath& buffer_p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions) { using tvm::tirx::Var; @@ -53,14 +53,14 @@ ffi::Map BufferAttrs(tirx::Buffer buffer, const AccessPath auto is_new_var = [&](const PrimExpr& e) { return e->IsInstance() && !d->IsVarDefined(e); }; - auto add_out_of_line_var_def = [&](const Var& var, const AccessPath& var_p) { + auto add_out_of_line_var_def = [&](const Var& var, const ffi::reflection::AccessPath& var_p) { TVM_FFI_ICHECK(!d->IsVarDefined(var)); ExprDoc lhs = DefineVar(var, frame, d); lhs->source_paths.push_back(var_p); var_def_lhs.push_back(lhs); var_def_rhs.push_back(PrintVarCreation(var, var_p, d)); }; - auto try_inline_def = [&](const PrimExpr& e, const AccessPath& e_p, + auto try_inline_def = [&](const PrimExpr& e, const ffi::reflection::AccessPath& e_p, std::function inline_f) { TVM_FFI_ICHECK(is_new_var(e)); Var var = Downcast(e); @@ -75,13 +75,13 @@ ffi::Map BufferAttrs(tirx::Buffer buffer, const AccessPath // Step 1. Handle `buffer.shape` { const ffi::Array& shape = buffer->shape; - AccessPath shape_p = buffer_p->Attr("shape"); + ffi::reflection::AccessPath shape_p = buffer_p->Attr("shape"); int n = shape.size(); ffi::Array results; results.reserve(n); for (int i = 0; i < n; ++i) { PrimExpr e = shape[i]; - AccessPath e_p = shape_p->ArrayItem(i); + ffi::reflection::AccessPath e_p = shape_p->ArrayItem(i); if (is_new_var(e)) { add_out_of_line_var_def(Downcast(e), e_p); } @@ -110,13 +110,13 @@ ffi::Map BufferAttrs(tirx::Buffer buffer, const AccessPath // Step 4. Handle `buffer.strides` if (!buffer->strides.empty()) { const ffi::Array& strides = buffer->strides; - AccessPath strides_p = buffer_p->Attr("strides"); + ffi::reflection::AccessPath strides_p = buffer_p->Attr("strides"); int n = strides.size(); ffi::Array results; results.reserve(n); for (int i = 0; i < n; ++i) { PrimExpr e = strides[i]; - AccessPath e_p = strides_p->ArrayItem(i); + ffi::reflection::AccessPath e_p = strides_p->ArrayItem(i); if (is_new_var(e)) { if (try_inline_def(e, e_p, [=]() { return d->AsDoc(buffer, buffer_p) @@ -203,14 +203,14 @@ ExprDoc BufferCall(const ExprDoc& prefix, const ffi::Map& } ExprDoc BufferDecl(const tirx::Buffer& buffer, const ffi::String& method, - const ffi::Array& args, const AccessPath& p, const Frame& frame, + const ffi::Array& args, const ffi::reflection::AccessPath& p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions) { return BufferCall(/*prefix=*/TIR(d, method), /*attrs=*/BufferAttrs(buffer, p, frame, d, var_definitions), /*args=*/args); } -ExprDoc BufferAttn(const tirx::Buffer& buffer, const AccessPath& p, const Frame& frame, +ExprDoc BufferAttn(const tirx::Buffer& buffer, const ffi::reflection::AccessPath& p, const Frame& frame, const IRDocsifier& d) { ffi::Map attrs = BufferAttrs(buffer, p, frame, d, BufferVarDefinition::DataPointer); @@ -220,7 +220,7 @@ ExprDoc BufferAttn(const tirx::Buffer& buffer, const AccessPath& p, const Frame& return TIR(d, "Buffer")->Call({shape, dtype}, {}, {}); } -ffi::Array BufferIndices(const ffi::Array& indices, const AccessPath& p, +ffi::Array BufferIndices(const ffi::Array& indices, const ffi::reflection::AccessPath& p, const IRDocsifier& d) { int n = indices.size(); ffi::Array indices_doc; @@ -228,8 +228,8 @@ ffi::Array BufferIndices(const ffi::Array& indices, const AccessP for (int i = 0; i < n; ++i) { if (const auto* ramp = indices[i].as()) { if (const auto* stride = ramp->stride.as()) { - AccessPath ramp_p = p->Attr("indices")->ArrayItem(i); - AccessPath stride_p = ramp_p->Attr("stride"); + ffi::reflection::AccessPath ramp_p = p->Attr("indices")->ArrayItem(i); + ffi::reflection::AccessPath stride_p = ramp_p->Attr("stride"); ExprDoc start = d->AsDoc(ramp->base, // ramp_p->Attr("base")); ExprDoc stop = d->AsDoc(ramp->base + ramp->lanes * ramp->stride, // @@ -247,14 +247,14 @@ ffi::Array BufferIndices(const ffi::Array& indices, const AccessP return indices_doc; } -ffi::Array BufferSlices(const ffi::Array& region, const AccessPath& p, +ffi::Array BufferSlices(const ffi::Array& region, const ffi::reflection::AccessPath& p, const IRDocsifier& d) { int n = region.size(); ffi::Array indices; indices.reserve(n); for (int i = 0; i < n; ++i) { Range range = region[i]; - AccessPath range_p = p->ArrayItem(i); + ffi::reflection::AccessPath range_p = p->ArrayItem(i); ExprDoc min = d->AsDoc(range->min, range_p->Attr("min")); if (tirx::is_one(range->extent)) { indices.push_back(min); @@ -268,14 +268,14 @@ ffi::Array BufferSlices(const ffi::Array& region, const AccessPath& TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tirx::BufferRegion buffer_region, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::BufferRegion buffer_region, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ExprDoc prefix = d->AsDoc(buffer_region->buffer, p->Attr("buffer")); return prefix[BufferSlices(buffer_region->region, p->Attr("region"), d)]; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::BufferStore store, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::BufferStore store, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(store->buffer, p->Attr("buffer")); ExprDoc value = d->AsDoc(store->value, p->Attr("value")); @@ -294,7 +294,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::BufferLoad load, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::BufferLoad load, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(load->buffer, p->Attr("buffer")); // Use .vload(...) syntax when there is a predicate @@ -308,7 +308,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tirx::Buffer buffer, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Buffer buffer, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { if (!d->IsVarDefined(buffer)) { if (ffi::Optional opt_f = FindLowestVarDef(buffer, d)) { ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d); @@ -326,7 +326,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tirx::MatchBufferRegion stmt, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::MatchBufferRegion stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { Frame frame = d->frames.back(); ExprDoc lhs = DefineBuffer(stmt->buffer, frame, d); ExprDoc src_buffer = d->AsDoc(stmt->source, p->Attr("source")); @@ -337,7 +337,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::ProducerLoad load, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::ProducerLoad load, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ExprDoc prefix = IdDoc(load->producer->GetNameHint()); return prefix[BufferIndices(load->indices, p->Attr("indices"), d)]; }); diff --git a/src/tirx/script/printer/expr.cc b/src/tirx/script/printer/expr.cc index 6d2e13cbd4b7..6149c208d11f 100644 --- a/src/tirx/script/printer/expr.cc +++ b/src/tirx/script/printer/expr.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include "./utils.h" @@ -24,9 +25,9 @@ namespace tvm { namespace script { namespace printer { -ExprDoc PrintVarCreation(const tirx::Var& var, const AccessPath& var_p, const IRDocsifier& d) { +ExprDoc PrintVarCreation(const tirx::Var& var, const ffi::reflection::AccessPath& var_p, const IRDocsifier& d) { Type type = var->type_annotation; - AccessPath type_p = var_p->Attr("type_annotation"); + ffi::reflection::AccessPath type_p = var_p->Attr("type_annotation"); ExprDoc rhs{ffi::UnsafeInit()}; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -64,7 +65,7 @@ ExprDoc PrintVarCreation(const tirx::Var& var, const AccessPath& var_p, const IR return rhs; } -Doc PrintVar(const tirx::Var& var, const AccessPath& var_p, const IRDocsifier& d) { +Doc PrintVar(const tirx::Var& var, const ffi::reflection::AccessPath& var_p, const IRDocsifier& d) { if (!d->IsVarDefined(var)) { if (ffi::Optional opt_f = FindLowestVarDef(var, d)) { ExprDoc lhs = DefineVar(var, opt_f.value(), d); @@ -82,17 +83,17 @@ Doc PrintVar(const tirx::Var& var, const AccessPath& var_p, const IRDocsifier& d } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tirx::Var var, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Var var, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return PrintVar(var, p, d); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tirx::SizeVar var, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::SizeVar var, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return PrintVar(var, p, d); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::IterVar var, AccessPath var_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::IterVar var, ffi::reflection::AccessPath var_p, IRDocsifier d) -> Doc { return TIR(d, "iter_var") ->Call({ d->AsDoc(var->var, var_p->Attr("var")), @@ -103,7 +104,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Not node, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Not node, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ExprDoc a = d->AsDoc(node->a, p->Attr("a")); if (a->IsInstance()) { return TIR(d, "Not")->Call({a}); @@ -112,7 +113,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::StringImm s, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::StringImm s, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { if (HasMultipleLines(s->value)) { return d->AddMetadata(s); } else { @@ -121,14 +122,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Cast cast, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Cast cast, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ExprDoc dtype = LiteralDoc::DataType(cast->dtype, p->Attr("dtype")); ExprDoc value = d->AsDoc(cast->value, p->Attr("value")); return TIR(d, "Cast")->Call({dtype, value}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Select select, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Select select, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Select") ->Call({ d->AsDoc(select->condition, p->Attr("condition")), @@ -138,7 +139,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Ramp ramp, AccessPath ramp_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Ramp ramp, ffi::reflection::AccessPath ramp_p, IRDocsifier d) -> Doc { return TIR(d, "Ramp")->Call({ d->AsDoc(ramp->base, ramp_p->Attr("base")), d->AsDoc(ramp->stride, ramp_p->Attr("stride")), @@ -148,7 +149,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", - [](tirx::Broadcast bc, AccessPath bc_p, IRDocsifier d) -> Doc { + [](tirx::Broadcast bc, ffi::reflection::AccessPath bc_p, IRDocsifier d) -> Doc { return TIR(d, "Broadcast") ->Call({ d->AsDoc(bc->value, bc_p->Attr("value")), @@ -158,7 +159,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::Shuffle shuffle, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::Shuffle shuffle, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Shuffle") ->Call({ d->AsDoc(shuffle->vectors, p->Attr("vectors")), @@ -168,7 +169,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::CommReducer r, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::CommReducer r, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { TVM_FFI_ICHECK_EQ(r->lhs.size(), r->rhs.size()); ffi::Optional lambda; { @@ -199,8 +200,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); LambdaDoc PrintIndexMap(const ffi::ObjectRef& map, const ffi::Array& vs, - const AccessPath& vs_p, const ffi::Array& es, - const AccessPath& es_p, const IRDocsifier& d) { + const ffi::reflection::AccessPath& vs_p, const ffi::Array& es, + const ffi::reflection::AccessPath& es_p, const IRDocsifier& d) { With f(d, map); ffi::Array vars; for (int i = 0, l = vs.size(); i < l; ++i) { @@ -215,7 +216,7 @@ LambdaDoc PrintIndexMap(const ffi::ObjectRef& map, const ffi::Array& TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::IndexMap m, AccessPath m_p, IRDocsifier d) -> Doc { + "", [](tirx::IndexMap m, ffi::reflection::AccessPath m_p, IRDocsifier d) -> Doc { LambdaDoc map = PrintIndexMap(m, m->initial_indices, m_p->Attr("initial_indices"), m->final_indices, m_p->Attr("final_indices"), d); if (m->inverse_index_map.defined()) { @@ -231,7 +232,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Let let, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Let let, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { DictDoc where({d->AsDoc(let->var, p->Attr("var"))}, {d->AsDoc(let->value, p->Attr("value"))}); return TIR(d, "Let")->Call({d->AsDoc(let->body, p->Attr("body"))}, // @@ -239,9 +240,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Call call, AccessPath call_p, IRDocsifier d) -> Doc { - static const OpAttrMap& op_names = - Op::GetAttrMap("TScriptPrinterName"); + .set_dispatch("", [](tirx::Call call, ffi::reflection::AccessPath call_p, IRDocsifier d) -> Doc { + static const OpAttrMap& op_names = + Op::GetAttrMap("TScriptPrinterName"); static const OpAttrMap dtype_locations = Op::GetAttrMap("TScriptDtypePrintLocation"); tirx::ScriptDtypePrintLocation dtype_print_location = tirx::ScriptDtypePrintLocation::kNone; @@ -304,7 +305,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Reduce r, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Reduce r, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ExprDoc combiner = d->AsDoc(r->combiner, p->Attr("combiner")); ExprDoc source = d->AsDoc(r->source, p->Attr("source")); ExprDoc init = d->AsDoc(r->init, p->Attr("init")); @@ -320,7 +321,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) #define TVM_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString) \ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ .set_dispatch("", \ - [](tirx::NodeType node, AccessPath p, IRDocsifier d) -> Doc { \ + [](tirx::NodeType node, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ return TIR(d, OpString)->Call({a, b}); \ @@ -336,7 +337,7 @@ bool IsNumber(const ExprDoc& e) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Div node, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Div node, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ExprDoc a = d->AsDoc(node->a, p->Attr("a")); ExprDoc b = d->AsDoc(node->b, p->Attr("b")); PrimExpr ret = tvm::div(node->a, node->b); @@ -353,7 +354,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) #define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc, OpString, OpKind) \ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ .set_dispatch( \ - "", [](tirx::NodeType node, AccessPath p, IRDocsifier d) -> Doc { \ + "", [](tirx::NodeType node, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ PrimExpr ret = tvm::NodeFunc(node->a, node->b); \ diff --git a/src/tirx/script/printer/for_loop.cc b/src/tirx/script/printer/for_loop.cc index 9897dd2189b9..d0ea417a3635 100644 --- a/src/tirx/script/printer/for_loop.cc +++ b/src/tirx/script/printer/for_loop.cc @@ -23,7 +23,7 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::For loop, AccessPath loop_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::For loop, ffi::reflection::AccessPath loop_p, IRDocsifier d) -> Doc { // Step 1. Check syntactic sugar: `T.grid` std::vector grid; std::unordered_set grid_loop_vars; diff --git a/src/tirx/script/printer/function.cc b/src/tirx/script/printer/function.cc index a743539c5361..15937d721cb1 100644 --- a/src/tirx/script/printer/function.cc +++ b/src/tirx/script/printer/function.cc @@ -65,7 +65,7 @@ int CountVarOccurrence(const tirx::PrimFunc& f, const tirx::Var& v) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::PrimFunc func, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::PrimFunc func, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { With f(d, func); (*f)->AddDispatchToken(d, "tirx"); IdDoc func_name = IdDoc(FindFunctionName(d, func).value_or("main")); @@ -87,12 +87,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::unordered_set buffer_inlined; for (int i = 0; i < n_args; ++i) { tirx::Var var = func->params[i]; - AccessPath var_p = p->Attr("params")->ArrayItem(i); + ffi::reflection::AccessPath var_p = p->Attr("params")->ArrayItem(i); if (d->cfg->syntax_sugar && CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) { tirx::Buffer buffer = func->buffer_map[var]; if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) { - AccessPath buffer_p = p->Attr("buffer_map")->MapItem(var); + ffi::reflection::AccessPath buffer_p = p->Attr("buffer_map")->MapItem(var); IdDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc annotation = BufferAttn(buffer, buffer_p, *f, d); args.push_back(AssignDoc(lhs, std::nullopt, annotation)); @@ -135,7 +135,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) continue; } ExprDoc param_doc = args[i]->lhs; - AccessPath buffer_p = p->Attr("buffer_map")->MapItem(param); + ffi::reflection::AccessPath buffer_p = p->Attr("buffer_map")->MapItem(param); ExprDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc}, buffer_p, *f, d, BufferVarDefinition::MatchBuffer); @@ -165,12 +165,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }(); if (d->cfg->syntax_sugar && implicit_root_block) { tirx::SBlock root_block = implicit_root_block.value(); - AccessPath root_block_p = p->Attr("body")->Attr("block"); + ffi::reflection::AccessPath root_block_p = p->Attr("body")->Attr("block"); (*f)->stmts.push_back(CommentDoc("with T.sblock(\"root\"):")); // Handle root block `alloc_buffer` for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) { tirx::Buffer buffer = root_block->alloc_buffers[i]; - AccessPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayItem(i); + ffi::reflection::AccessPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayItem(i); IdDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc rhs = BufferDecl(buffer, "sblock_alloc_buffer", {}, buffer_p, *f, d, BufferVarDefinition::DataPointer); @@ -193,7 +193,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (!func->attrs.defined() || !func->attrs->dict.count(tvm::attr::kGlobalSymbol)) { ffi::Array pos_args; decorator = decorator->Call(pos_args, {"private"}, - {LiteralDoc::Boolean(true, ffi::Optional())}); + {LiteralDoc::Boolean(true, ffi::Optional())}); } return HeaderWrapper(d, FunctionDoc( @@ -208,7 +208,7 @@ TVM_REGISTER_SCRIPT_AS_REPR(tirx::PrimFuncNode, ReprPrintTIR); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "tirx", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // + "tirx", [](tvm::GlobalVar n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } else { @@ -220,7 +220,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "tirx", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // + "tirx", [](tvm::IRModule mod, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // ffi::Optional doc = d->GetVarDoc(mod); TVM_FFI_ICHECK(doc) << "Unable to print IRModule before definition in TIR."; return doc.value(); diff --git a/src/tirx/script/printer/ir.cc b/src/tirx/script/printer/ir.cc index 57bec5a56136..63af4bab0b4a 100644 --- a/src/tirx/script/printer/ir.cc +++ b/src/tirx/script/printer/ir.cc @@ -27,7 +27,7 @@ namespace printer { TVM_FFI_STATIC_INIT_BLOCK() { TIRFrameNode::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](IntImm imm, AccessPath imm_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](IntImm imm, ffi::reflection::AccessPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; if (dtype == d->cfg->int_dtype) { return LiteralDoc::Int(imm->value, imm_p->Attr("value")); @@ -40,7 +40,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](FloatImm imm, AccessPath imm_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](FloatImm imm, ffi::reflection::AccessPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; if (dtype == d->cfg->float_dtype) { return LiteralDoc::Float(imm->value, imm_p->Attr("value")); @@ -51,7 +51,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("tirx", [](Range range, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("tirx", [](Range range, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Range") ->Call({ d->AsDoc(range->min, p->Attr("min")), @@ -60,12 +60,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](PrimType ty, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](PrimType ty, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return TIR(d, DType2Str(ty->dtype)); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](PointerType ty, AccessPath ty_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](PointerType ty, ffi::reflection::AccessPath ty_p, IRDocsifier d) -> Doc { ExprDoc element_type{ffi::UnsafeInit()}; if (const auto* prim_type = ty->element_type.as()) { element_type = LiteralDoc::DataType(prim_type->dtype, // @@ -82,7 +82,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](TupleType ty, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](TupleType ty, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { if (ty->fields.empty()) { return LiteralDoc::None(p); } @@ -90,7 +90,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](Target target, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](Target target, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ffi::Map config = target->ToConfig(); return TIR(d, "target")->Call({d->AsDoc(config, p)}); }); diff --git a/src/tirx/script/printer/stmt.cc b/src/tirx/script/printer/stmt.cc index 3c3ab21f9338..46dafdd62588 100644 --- a/src/tirx/script/printer/stmt.cc +++ b/src/tirx/script/printer/stmt.cc @@ -81,7 +81,7 @@ ffi::Optional FindReturnValue(const tirx::Stmt& node) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Evaluate eval, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Evaluate eval, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { if (d->cfg->syntax_sugar) { if (auto return_value = FindReturnValue(eval)) { ExprDoc value = @@ -98,7 +98,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Bind stmt, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Bind stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { // Step 1. Type annotation ffi::Optional type_doc = d->AsDoc(stmt->var->type_annotation, // p->Attr("var")->Attr("type_annotation")); @@ -122,7 +122,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tirx::AssertStmt stmt, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::AssertStmt stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); // Always emit the canonical tuple form: assert cond, ("Kind", ["part0", "part1", ...]) ffi::Array parts; @@ -135,7 +135,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::While stmt, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::While stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); With f(d, stmt); AsDocBody(stmt->body, p->Attr("body"), f->get(), d); @@ -143,7 +143,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); namespace { -Doc DeclBufferDoc(tirx::DeclBuffer stmt, AccessPath p, IRDocsifier d, +Doc DeclBufferDoc(tirx::DeclBuffer stmt, ffi::reflection::AccessPath p, IRDocsifier d, BufferVarDefinition var_definitions) { ExprDoc rhs = BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), d->frames.back(), d, var_definitions); @@ -154,13 +154,13 @@ Doc DeclBufferDoc(tirx::DeclBuffer stmt, AccessPath p, IRDocsifier d, TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::DeclBuffer stmt, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::DeclBuffer stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { return DeclBufferDoc(stmt, p, d, BufferVarDefinition::None); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::IfThenElse stmt, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::IfThenElse stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); ffi::Array then_branch; ffi::Array else_branch; @@ -178,13 +178,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::SeqStmt stmt, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::SeqStmt stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { With f(d, stmt); AsDocBody(stmt, p, f->get(), d); return StmtBlockDoc((*f)->stmts); }); -void InsertEnvThread(const tirx::IterVar& iter_var, const AccessPath& iter_var_p, +void InsertEnvThread(const tirx::IterVar& iter_var, const ffi::reflection::AccessPath& iter_var_p, const IRDocsifier& d) { Frame f = FindLowestVarDef(iter_var->var, d).value(); DefineVar(iter_var->var, f, d); @@ -195,10 +195,10 @@ void InsertEnvThread(const tirx::IterVar& iter_var, const AccessPath& iter_var_p f->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } -ExprDoc DocsifyLaunchThread(const tirx::AttrStmt& attr_stmt, const AccessPath& attr_stmt_p, +ExprDoc DocsifyLaunchThread(const tirx::AttrStmt& attr_stmt, const ffi::reflection::AccessPath& attr_stmt_p, ffi::Optional* define_var, const IRDocsifier& d) { tirx::IterVar iter_var = Downcast(attr_stmt->node); - AccessPath iter_var_p = attr_stmt_p->Attr("node"); + ffi::reflection::AccessPath iter_var_p = attr_stmt_p->Attr("node"); ExprDoc var_doc{ffi::UnsafeInit()}; if (d->IsVarDefined(iter_var->var)) { @@ -219,13 +219,13 @@ ExprDoc DocsifyLaunchThread(const tirx::AttrStmt& attr_stmt, const AccessPath& a TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::AttrStmt stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { + "", [](tirx::AttrStmt stmt, ffi::reflection::AccessPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); ffi::Optional lhs = std::nullopt; ffi::Optional rhs = std::nullopt; ffi::Optional define_var = std::nullopt; tirx::Stmt body = stmt->body; - AccessPath body_p = stmt_p->Attr("body"); + ffi::reflection::AccessPath body_p = stmt_p->Attr("body"); if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") { if (stmt->node.as()) { rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d); @@ -252,9 +252,9 @@ TVM_REGISTER_SCRIPT_AS_REPR(tirx::AssertStmtNode, ReprPrintTIR); TVM_REGISTER_SCRIPT_AS_REPR(tirx::WhileNode, ReprPrintTIR); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::AllocBuffer stmt, AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::AllocBuffer stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { tirx::Buffer buffer = stmt->buffer; - AccessPath buffer_p = p->Attr("buffer"); + ffi::reflection::AccessPath buffer_p = p->Attr("buffer"); Frame frame = d->frames.back(); // Define buffer's data var inline as buffer.data if (!d->IsVarDefined(buffer->data)) { @@ -272,10 +272,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) int n = buffer->shape.size(); ffi::Array shape_docs; shape_docs.reserve(n); - AccessPath shape_p = buffer_p->Attr("shape"); + ffi::reflection::AccessPath shape_p = buffer_p->Attr("shape"); for (int i = 0; i < n; ++i) { PrimExpr e = buffer->shape[i]; - AccessPath e_p = shape_p->ArrayItem(i); + ffi::reflection::AccessPath e_p = shape_p->ArrayItem(i); if (!d->IsVarDefined(e) && e->IsInstance()) { ExprDoc lhs = DefineVar(Downcast(e), frame, d); lhs->source_paths.push_back(e_p); diff --git a/src/tirx/script/printer/utils.h b/src/tirx/script/printer/utils.h index 8dc6e703bccd..3207edf8bfe5 100644 --- a/src/tirx/script/printer/utils.h +++ b/src/tirx/script/printer/utils.h @@ -108,7 +108,7 @@ inline IdDoc DefineBuffer(const tirx::Buffer& buffer, const Frame& frame, const * \param f The frame * \param d The IRDocsifier */ -inline void AsDocBody(const tirx::Stmt& stmt, AccessPath p, TIRFrameNode* f, const IRDocsifier& d) { +inline void AsDocBody(const tirx::Stmt& stmt, ffi::reflection::AccessPath p, TIRFrameNode* f, const IRDocsifier& d) { if (const auto* seq_stmt = stmt.as()) { ffi::Array body = seq_stmt->seq; for (int i = 0, n = body.size(); i < n; ++i) { @@ -214,7 +214,7 @@ enum class BufferVarDefinition { * \return The ExprDoc corresponding to the buffer declaration */ ExprDoc BufferDecl(const tirx::Buffer& buffer, const ffi::String& method, - const ffi::Array& args, const AccessPath& p, const Frame& frame, + const ffi::Array& args, const ffi::reflection::AccessPath& p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions); /*! @@ -225,7 +225,7 @@ ExprDoc BufferDecl(const tirx::Buffer& buffer, const ffi::String& method, * \param d The IRDocsifier * \return The ExprDoc corresponding to the buffer declaration */ -ExprDoc BufferAttn(const tirx::Buffer& buffer, const AccessPath& p, const Frame& frame, +ExprDoc BufferAttn(const tirx::Buffer& buffer, const ffi::reflection::AccessPath& p, const Frame& frame, const IRDocsifier& d); /*! @@ -235,7 +235,7 @@ ExprDoc BufferAttn(const tirx::Buffer& buffer, const AccessPath& p, const Frame& * \param d The IRDocsifier * \return The ExprDoc corresponding to the Var creation */ -ExprDoc PrintVarCreation(const tirx::Var& var, const AccessPath& var_p, const IRDocsifier& d); +ExprDoc PrintVarCreation(const tirx::Var& var, const ffi::reflection::AccessPath& var_p, const IRDocsifier& d); /*! \brief A Var occurrence counter visitor */ class OccurrenceCounter : public tirx::StmtExprVisitor { diff --git a/src/tirx/transform/ir_utils.cc b/src/tirx/transform/ir_utils.cc index 9130bca9c091..96b0209415ba 100644 --- a/src/tirx/transform/ir_utils.cc +++ b/src/tirx/transform/ir_utils.cc @@ -682,13 +682,13 @@ ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, return result; } -Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region) { +ffi::Array ConvertRegion(const MatchBufferRegion& match_buffer, const ffi::Array& region) { const Buffer& target = match_buffer->buffer; const BufferRegion& source = match_buffer->source; TVM_FFI_ICHECK_EQ(region.size(), target->shape.size()); arith::Analyzer analyzer; - Region result; + ffi::Array result; result.reserve(source->region.size()); size_t offset = source->region.size() - region.size(); for (size_t i = 0; i < offset; ++i) { diff --git a/src/tirx/transform/ir_utils.h b/src/tirx/transform/ir_utils.h index f77d73fbcff0..6427ae43a2e2 100644 --- a/src/tirx/transform/ir_utils.h +++ b/src/tirx/transform/ir_utils.h @@ -227,7 +227,7 @@ ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, * \param region The sub-region of the target buffer * \return The region of source buffer. */ -Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region); +ffi::Array ConvertRegion(const MatchBufferRegion& match_buffer, const ffi::Array& region); /*! * \brief Get stride aware buffer allocation shape from buffer. From 7c668e4f2ad08b7e334540403b3b966a9e0e0445 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 30 Apr 2026 13:08:22 +0000 Subject: [PATCH 2/4] [REFACTOR][RUNTIME] Switch icheck-only callers from runtime/logging.h to tvm/ffi/error.h MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For 64 files (7 in include/, 57 in src/) that include but only use TVM_FFI_ICHECK / TVM_FFI_THROW macros (no LOG/VLOG/DLOG logging features), switch the include to which provides those macros directly. Files that use LOG(WARNING), LOG(INFO), VLOG, DLOG, VLOG_CONTEXT, or similar logging features keep . No ICHECK → TVM_FFI_ICHECK rewrites needed: bare ICHECK is already absent across include/ and src/ (only TVM_FFI_ICHECK is used). Also adds explicit to ~70 src/ files that use LOG/VLOG/DLOG macros but previously relied on a transitive include chain through include/tvm/ir/op.h (which switched from logging.h to ffi/error.h). Making these includes explicit is correct hygiene regardless; switching op.h's include only makes it visible. --- include/tvm/ir/node_functor.h | 2 +- include/tvm/ir/op.h | 2 +- include/tvm/relax/utils.h | 2 +- include/tvm/runtime/device_api.h | 2 +- include/tvm/runtime/vm/bytecode.h | 2 +- include/tvm/s_tir/random_engine.h | 2 +- include/tvm/topi/detail/constant_utils.h | 1 + src/arith/analyzer.cc | 1 + src/arith/const_fold.h | 1 + src/arith/int_set.cc | 1 + src/arith/solve_linear_equation.cc | 1 + src/ir/diagnostic.cc | 1 + src/ir/instrument.cc | 2 + src/ir/repr.cc | 6 +- src/ir/source_map.cc | 1 + src/ir/transform.cc | 1 + src/relax/analysis/graph_partitioner.h | 1 + src/relax/analysis/layout_transformation.cc | 1 + src/relax/analysis/well_formed.cc | 1 + src/relax/backend/contrib/clml/codegen.cc | 1 + src/relax/backend/contrib/cutlass/codegen.cc | 1 + src/relax/backend/contrib/tensorrt/codegen.cc | 1 + src/relax/backend/vm/codegen_vm_tir.cc | 1 + src/relax/ir/block_builder.cc | 1 + src/relax/ir/dataflow_matcher.cc | 1 + src/relax/ir/transform.cc | 1 + src/relax/op/tensor/index.cc | 1 + src/relax/op/tensor/manipulate.cc | 1 + src/relax/transform/bundle_model_params.cc | 2 +- .../transform/eliminate_common_subexpr.cc | 1 + src/relax/transform/fold_constant.cc | 1 + src/relax/transform/fuse_ops.cc | 1 + src/relax/transform/lambda_lift.cc | 2 +- src/relax/transform/lift_transform_params.cc | 2 +- src/relax/transform/meta_schedule.cc | 1 + src/runtime/const_loader_module.cc | 1 + src/runtime/contrib/cblas/cblas.cc | 2 +- src/runtime/contrib/cblas/dnnl_blas.cc | 2 +- src/runtime/contrib/cblas/mkl.cc | 2 +- src/runtime/contrib/clml/clml_runtime.cc | 1 + src/runtime/contrib/cublas/cublas.cc | 2 +- src/runtime/contrib/cublas/cublas_utils.h | 2 +- src/runtime/contrib/cudnn/cudnn_utils.h | 2 +- src/runtime/contrib/dnnl/dnnl_kernel.h | 2 +- .../example_npu/example_npu_runtime.cc | 1 + src/runtime/contrib/hipblas/hipblas.cc | 2 +- src/runtime/contrib/hipblas/hipblas_utils.h | 2 +- .../contrib/random/mt_random_engine.cc | 2 +- src/runtime/contrib/random/random.cc | 2 +- .../contrib/tensorrt/tensorrt_runtime.cc | 1 + src/runtime/cpu_device_api.cc | 2 +- .../disco/distributed/socket_session.cc | 1 + src/runtime/disco/nccl/nccl.cc | 1 + src/runtime/file_utils.cc | 2 +- src/runtime/hexagon/hexagon_buffer.cc | 2 + src/runtime/hexagon/hexagon_buffer.h | 2 +- src/runtime/hexagon/hexagon_common.h | 2 +- src/runtime/hexagon/hexagon_thread_manager.cc | 1 + src/runtime/hexagon/hexagon_thread_manager.h | 2 +- src/runtime/hexagon/hexagon_vtcm_pool.cc | 1 + src/runtime/hexagon/hexagon_vtcm_pool.h | 2 +- src/runtime/hexagon/qhl/qhl_wrapper.cc | 2 +- src/runtime/hexagon/rpc/hexagon/rpc_server.cc | 1 + src/runtime/hexagon/rpc/simulator/session.cc | 1 + src/runtime/memory/memory_manager.cc | 1 + src/runtime/memory/naive_allocator.h | 1 + src/runtime/memory/pooled_allocator.h | 1 + src/runtime/metal/metal_common.h | 2 +- src/runtime/minrpc/minrpc_server.h | 2 +- src/runtime/opencl/opencl_common.h | 2 +- .../opencl/opencl_wrapper/opencl_wrapper.cc | 2 +- src/runtime/rpc/rpc_channel.cc | 2 +- src/runtime/rpc/rpc_server_env.cc | 1 + src/runtime/static_library.cc | 1 + src/runtime/static_library.h | 2 +- src/runtime/thread_pool.cc | 2 +- src/runtime/timer.cc | 1 + src/runtime/vm/attn_backend.h | 2 +- src/runtime/vm/bytecode.cc | 2 +- src/runtime/vm/kv_state.h | 2 +- src/runtime/vm/lm_support.cc | 2 +- src/runtime/vm/paged_kv_cache.cc | 2 +- src/runtime/vulkan/spirv_shader.h | 2 +- src/runtime/vulkan/vulkan_common.h | 2 +- src/runtime/vulkan/vulkan_device.h | 2 +- src/runtime/vulkan/vulkan_instance.cc | 1 + .../rewrite_parallel_vectorize_unroll.cc | 1 + .../postproc/rewrite_tensorize.cc | 1 + .../schedule_rule/apply_custom_rule.cc | 1 + .../schedule_rule/multi_level_tiling.cc | 1 + .../space_generator/space_generator.cc | 1 + src/s_tir/meta_schedule/utils.h | 1 + src/s_tir/schedule/concrete_schedule.cc | 1 + .../schedule/primitive/blockize_tensorize.cc | 1 + .../primitive/layout_transformation.cc | 1 + src/s_tir/support/parallel_for.h | 2 +- src/s_tir/support/table_printer.h | 2 +- src/s_tir/transform/inject_double_buffer.cc | 1 + src/s_tir/transform/loop_partition.cc | 1 + src/s_tir/transform/lower_async_dma.cc | 1 + src/script/ir_builder/ir/ir.cc | 1 + src/support/base64.h | 2 +- src/support/pipe.h | 2 +- src/support/ring_buffer.h | 2 +- src/target/canonicalizer/llvm/arm_aprofile.cc | 2 + src/target/cuda/codegen_cuda.cc | 1 + src/target/cuda/ptx.h | 2 +- src/target/hexagon/llvm/codegen_hexagon.cc | 1 + src/target/intrin_rule.cc | 1 + src/target/llvm/codegen_aarch64.cc | 1 + src/target/llvm/codegen_cpu.cc | 1 + src/target/metal/codegen_metal.cc | 1 + src/target/rocm/llvm/codegen_amdgpu.cc | 1 + src/target/target_kind.cc | 1 + src/tirx/transform/lower_intrin.cc | 1 + src/tirx/transform/lower_tvm_builtin.cc | 1 + src/tirx/transform/storage_rewrite.cc | 1 + src/tirx/transform/tvm_ffi_binder.cc | 63 ++++++++++--------- src/tirx/transform/tvm_ffi_binder.h | 38 +++++------ src/tirx/transform/vectorize_loop.cc | 1 + 120 files changed, 174 insertions(+), 101 deletions(-) diff --git a/include/tvm/ir/node_functor.h b/include/tvm/ir/node_functor.h index c27dc9ec4b87..c7be2188d314 100644 --- a/include/tvm/ir/node_functor.h +++ b/include/tvm/ir/node_functor.h @@ -23,7 +23,7 @@ #ifndef TVM_IR_NODE_FUNCTOR_H_ #define TVM_IR_NODE_FUNCTOR_H_ -#include +#include #include #include diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index d8b2cf07d11f..dc8f99cd4789 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -25,6 +25,7 @@ #ifndef TVM_IR_OP_H_ #define TVM_IR_OP_H_ +#include #include #include #include @@ -32,7 +33,6 @@ #include #include #include -#include #include #include diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 792f7dd11f90..bfbcaa069818 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -25,9 +25,9 @@ #define TVM_RELAX_UTILS_H_ #include +#include #include #include -#include namespace tvm { namespace relax { diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 47607c5b8875..be5d4e89005b 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -25,10 +25,10 @@ #define TVM_RUNTIME_DEVICE_API_H_ #include +#include #include #include #include -#include #include /*! diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h index 5a60febf8443..0f1927e0cbcb 100644 --- a/include/tvm/runtime/vm/bytecode.h +++ b/include/tvm/runtime/vm/bytecode.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_VM_BYTECODE_H_ #define TVM_RUNTIME_VM_BYTECODE_H_ +#include #include -#include #include #include diff --git a/include/tvm/s_tir/random_engine.h b/include/tvm/s_tir/random_engine.h index d594e1ba0c35..0acfd50fbed2 100644 --- a/include/tvm/s_tir/random_engine.h +++ b/include/tvm/s_tir/random_engine.h @@ -23,7 +23,7 @@ */ #ifndef TVM_S_TIR_RANDOM_ENGINE_H_ #define TVM_S_TIR_RANDOM_ENGINE_H_ -#include +#include #include #include diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h index a77177984734..07df5c470bf4 100644 --- a/include/tvm/topi/detail/constant_utils.h +++ b/include/tvm/topi/detail/constant_utils.h @@ -25,6 +25,7 @@ #define TVM_TOPI_DETAIL_CONSTANT_UTILS_H_ #include +#include #include #include #include diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 4650cfb43b1c..8bce80f4ef8f 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 8464443118f9..91db540f2e82 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -25,6 +25,7 @@ #define TVM_ARITH_CONST_FOLD_H_ #include +#include #include #include diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 66c148d47857..6b3e2b953270 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 8d6b58351359..4b6ac036e8bb 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 18fa77f62658..f234963c9261 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -25,6 +25,7 @@ #include #include #include +#include namespace tvm { diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index ad47ccf2ed44..e88713a50632 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -26,8 +26,10 @@ #include #include #include +#include #include +#include #include namespace tvm { diff --git a/src/ir/repr.cc b/src/ir/repr.cc index cf15ecbbf685..addbd33209f3 100644 --- a/src/ir/repr.cc +++ b/src/ir/repr.cc @@ -24,7 +24,7 @@ * The legacy ReprPrinter has been replaced by ffi::ReprPrint. This file: * - Implements the Dump() debug helpers (they call ffi::ReprPrint). * - Registers node.AsRepr (for backward Python compatibility) via ffi::ReprPrint. - * - Registers __ffi_repr__ hooks for AccessPath and AccessStep. + * - Registers __ffi_repr__ hooks for ffi::reflection::AccessPath and AccessStep. */ #include #include @@ -48,7 +48,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Python's tvm.runtime._ffi_node_api sets __object_repr__ = AsRepr via init_ffi_api. refl::GlobalDef().def("node.AsRepr", [](ffi::Any obj) -> ffi::String { return ffi::ReprPrint(obj); }); - // Register __ffi_repr__ for AccessPath/AccessStep so that ffi.ReprPrint + // Register __ffi_repr__ for ffi::reflection::AccessPath/AccessStep so that ffi.ReprPrint // uses the concise ".field[idx]" format. // // AccessStep: format one step fragment (e.g. ".field", "[0]", "[key]?"). @@ -79,7 +79,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } return os.str(); }); - // AccessPath: recurse through parent via fn_repr rather than walking the + // ffi::reflection::AccessPath: recurse through parent via fn_repr rather than walking the // linked list manually. Root (no step) emits ""; each non-root node // prepends its parent's repr and appends the current step's repr. refl::TypeAttrDef().def( diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index e61cd8db753d..96f4b8fda973 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include diff --git a/src/ir/transform.cc b/src/ir/transform.cc index c301037732d2..82c3f13c5618 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include diff --git a/src/relax/analysis/graph_partitioner.h b/src/relax/analysis/graph_partitioner.h index 7084139e299b..1ae994842b1f 100644 --- a/src/relax/analysis/graph_partitioner.h +++ b/src/relax/analysis/graph_partitioner.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index 4fa4ed48534e..dcee90c9a7ec 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 875489d43815..ec654c6eb0ef 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -72,6 +72,7 @@ #include #include #include +#include #include #include diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index 24122ddf5241..eaa57f8315e4 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 4a91dcab9cff..91840f6936e5 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc index 011b8138a595..2be214ed941c 100644 --- a/src/relax/backend/contrib/tensorrt/codegen.cc +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index b143b1473d05..10da7d983619 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index b2a70a39c266..1061c02eb1f8 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include #include diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 37e90c530930..22e3a7bbc31a 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include #include diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index a7d7047c3095..4b4c7077c64d 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -30,6 +30,7 @@ #include #include #include +#include namespace tvm { namespace relax { diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index f954e901a8ea..efa221fd64f3 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index b8045a12c2cf..c0b82a760d13 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include diff --git a/src/relax/transform/bundle_model_params.cc b/src/relax/transform/bundle_model_params.cc index b8b4825e35ba..b4e4f186d19d 100644 --- a/src/relax/transform/bundle_model_params.cc +++ b/src/relax/transform/bundle_model_params.cc @@ -23,12 +23,12 @@ */ #include +#include #include #include #include #include #include -#include #include "utils.h" diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 7be779984ce9..0d0b8de82a1d 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -32,6 +32,7 @@ #include #include #include +#include #include "../../support/utils.h" diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 934d93edf494..ed28e5dbc8da 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 5803dad48514..7af1bb0c8a6a 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index 387a22b7b385..6fb78bfb1422 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -23,12 +23,12 @@ */ #include +#include #include #include #include #include #include -#include #include #include diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 9be91ace1e01..5430e181bca4 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -23,13 +23,13 @@ */ #include +#include #include #include #include #include #include #include -#include #include #include diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 3d1df9773f5d..a9dd126a3e61 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index b7ce95dd2dbb..8592d228d0bd 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -36,6 +36,7 @@ #include #include #include +#include #include #include diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index a91db72e5dab..926ce0195245 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -20,10 +20,10 @@ /*! * \file Use external cblas library call. */ +#include #include #include #include -#include extern "C" { #include diff --git a/src/runtime/contrib/cblas/dnnl_blas.cc b/src/runtime/contrib/cblas/dnnl_blas.cc index d6a9baa21bc8..420e244301b2 100644 --- a/src/runtime/contrib/cblas/dnnl_blas.cc +++ b/src/runtime/contrib/cblas/dnnl_blas.cc @@ -20,10 +20,10 @@ /*! * \file Use external cblas library call. */ +#include #include #include #include -#include extern "C" { #include diff --git a/src/runtime/contrib/cblas/mkl.cc b/src/runtime/contrib/cblas/mkl.cc index 59783134157c..60fecc11bd66 100644 --- a/src/runtime/contrib/cblas/mkl.cc +++ b/src/runtime/contrib/cblas/mkl.cc @@ -20,10 +20,10 @@ /*! * \file Use external mkl library call. */ +#include #include #include #include -#include extern "C" { #include diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index 5ea6c1398eeb..dd66987bdd11 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -36,6 +36,7 @@ #include "clml_utils.h" #endif +#include #include namespace tvm { diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index dcaf93d2da2e..e58ffdeee0ba 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -20,11 +20,11 @@ /*! * \file Use external cblas library call. */ +#include #include #include #include #include -#include #include "../../3rdparty/compiler-rt/builtin_fp16.h" #include "../cblas/gemm_common.h" diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 429e9831146b..ad67eb1ee9e8 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #if CUDART_VERSION >= 10010 diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index 58eac57c679d..91f50dfc1c92 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -25,8 +25,8 @@ #define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_ #include +#include #include -#include #include diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index a407f5589c61..b7a0ec0f7314 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -25,9 +25,9 @@ #ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ +#include #include #include -#include #include diff --git a/src/runtime/contrib/example_npu/example_npu_runtime.cc b/src/runtime/contrib/example_npu/example_npu_runtime.cc index 440a5d9715ec..0408a3fe9acd 100644 --- a/src/runtime/contrib/example_npu/example_npu_runtime.cc +++ b/src/runtime/contrib/example_npu/example_npu_runtime.cc @@ -31,6 +31,7 @@ #include #include +#include #include #include diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc index b2cc7331117a..eca971e06606 100644 --- a/src/runtime/contrib/hipblas/hipblas.cc +++ b/src/runtime/contrib/hipblas/hipblas.cc @@ -20,10 +20,10 @@ /*! * \file Use external hipblas library call. */ +#include #include #include #include -#include #include "../../3rdparty/compiler-rt/builtin_fp16.h" #include "../cblas/gemm_common.h" diff --git a/src/runtime/contrib/hipblas/hipblas_utils.h b/src/runtime/contrib/hipblas/hipblas_utils.h index a44c984d9a3f..90c8c489d370 100644 --- a/src/runtime/contrib/hipblas/hipblas_utils.h +++ b/src/runtime/contrib/hipblas/hipblas_utils.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index 64c3ff66a7cf..c01fe9267326 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -22,9 +22,9 @@ * \brief mt19937 random engine */ #include +#include #include #include -#include #include #include diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index 1f3fdf869e99..af94f97ef16f 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -20,10 +20,10 @@ /*! * \file External random functions for tensor. */ +#include #include #include #include -#include #include diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index df8443dd590a..d4fcffd541bb 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index 9762d2c3b46f..b549182dab31 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -20,10 +20,10 @@ /*! * \file cpu_device_api.cc */ +#include #include #include #include -#include #include #include diff --git a/src/runtime/disco/distributed/socket_session.cc b/src/runtime/disco/distributed/socket_session.cc index d0b8d7df1640..a9d9d912aa82 100644 --- a/src/runtime/disco/distributed/socket_session.cc +++ b/src/runtime/disco/distributed/socket_session.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 1230cf15f8a7..3167ab243ca7 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 095b0288bad8..180f04da7dd7 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -22,10 +22,10 @@ */ #include "file_utils.h" +#include #include #include #include -#include #include #include diff --git a/src/runtime/hexagon/hexagon_buffer.cc b/src/runtime/hexagon/hexagon_buffer.cc index 84297c25e66b..b77638e45176 100644 --- a/src/runtime/hexagon/hexagon_buffer.cc +++ b/src/runtime/hexagon/hexagon_buffer.cc @@ -18,6 +18,8 @@ */ #include "hexagon_buffer.h" +#include + #include #include #include diff --git a/src/runtime/hexagon/hexagon_buffer.h b/src/runtime/hexagon/hexagon_buffer.h index 2dd7c127e3ed..0e578fccf477 100644 --- a/src/runtime/hexagon/hexagon_buffer.h +++ b/src/runtime/hexagon/hexagon_buffer.h @@ -20,10 +20,10 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_BUFFER_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_BUFFER_H_ +#include #include #include #include -#include #include #include diff --git a/src/runtime/hexagon/hexagon_common.h b/src/runtime/hexagon/hexagon_common.h index 335e611d603e..7ffc4457192a 100644 --- a/src/runtime/hexagon/hexagon_common.h +++ b/src/runtime/hexagon/hexagon_common.h @@ -26,7 +26,7 @@ #include #include #include -#include +#include #if defined(__hexagon__) #include diff --git a/src/runtime/hexagon/hexagon_thread_manager.cc b/src/runtime/hexagon/hexagon_thread_manager.cc index c1c3eadc3126..76e57c67e8a1 100644 --- a/src/runtime/hexagon/hexagon_thread_manager.cc +++ b/src/runtime/hexagon/hexagon_thread_manager.cc @@ -18,6 +18,7 @@ */ #include "hexagon_thread_manager.h" +#include namespace tvm { namespace runtime { diff --git a/src/runtime/hexagon/hexagon_thread_manager.h b/src/runtime/hexagon/hexagon_thread_manager.h index 83c5316a7259..c02e23f29c34 100644 --- a/src/runtime/hexagon/hexagon_thread_manager.h +++ b/src/runtime/hexagon/hexagon_thread_manager.h @@ -22,7 +22,7 @@ #include #include -#include +#include #include #include diff --git a/src/runtime/hexagon/hexagon_vtcm_pool.cc b/src/runtime/hexagon/hexagon_vtcm_pool.cc index ef3dc592f003..f96ba975da0d 100644 --- a/src/runtime/hexagon/hexagon_vtcm_pool.cc +++ b/src/runtime/hexagon/hexagon_vtcm_pool.cc @@ -17,6 +17,7 @@ * under the License. */ #include "hexagon_vtcm_pool.h" +#include #include "HAP_compute_res.h" #include "hexagon_common.h" diff --git a/src/runtime/hexagon/hexagon_vtcm_pool.h b/src/runtime/hexagon/hexagon_vtcm_pool.h index 0f7153eb54f6..5159c458c8d6 100644 --- a/src/runtime/hexagon/hexagon_vtcm_pool.h +++ b/src/runtime/hexagon/hexagon_vtcm_pool.h @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/runtime/hexagon/qhl/qhl_wrapper.cc b/src/runtime/hexagon/qhl/qhl_wrapper.cc index e1515ecc7e08..a90b8eb1618f 100644 --- a/src/runtime/hexagon/qhl/qhl_wrapper.cc +++ b/src/runtime/hexagon/qhl/qhl_wrapper.cc @@ -19,7 +19,7 @@ #if defined(__hexagon__) #include #include -#include +#include #define restrict __restrict__ #define LOG2VLEN 7 diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index cd78591b4dbb..9f20a8f6d229 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -29,6 +29,7 @@ extern "C" { #include #include #include +#include #include #include diff --git a/src/runtime/hexagon/rpc/simulator/session.cc b/src/runtime/hexagon/rpc/simulator/session.cc index 0864796a9ad9..918614afcde7 100644 --- a/src/runtime/hexagon/rpc/simulator/session.cc +++ b/src/runtime/hexagon/rpc/simulator/session.cc @@ -20,6 +20,7 @@ #include #include #include +#include // POSIX includes #include #include diff --git a/src/runtime/memory/memory_manager.cc b/src/runtime/memory/memory_manager.cc index ba96c0071e0d..626222e6c87f 100644 --- a/src/runtime/memory/memory_manager.cc +++ b/src/runtime/memory/memory_manager.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include diff --git a/src/runtime/memory/naive_allocator.h b/src/runtime/memory/naive_allocator.h index 6a968c86ef3b..134d05762286 100644 --- a/src/runtime/memory/naive_allocator.h +++ b/src/runtime/memory/naive_allocator.h @@ -24,6 +24,7 @@ #define TVM_RUNTIME_MEMORY_NAIVE_ALLOCATOR_H_ #include +#include #include #include diff --git a/src/runtime/memory/pooled_allocator.h b/src/runtime/memory/pooled_allocator.h index 620393466867..2862dde1ae6d 100644 --- a/src/runtime/memory/pooled_allocator.h +++ b/src/runtime/memory/pooled_allocator.h @@ -24,6 +24,7 @@ #define TVM_RUNTIME_MEMORY_POOLED_ALLOCATOR_H_ #include +#include #include #include diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index cc538f84dce0..ebbbcde071b4 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -33,7 +33,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index 434c88b693e9..84cf45a6ca92 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index a43b29d5ec59..d80a52e5e705 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc b/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc index e4c4a1a9af31..a10b1a81b837 100644 --- a/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc +++ b/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc @@ -33,7 +33,7 @@ #include #endif -#include +#include #include #include diff --git a/src/runtime/rpc/rpc_channel.cc b/src/runtime/rpc/rpc_channel.cc index f462dac3d257..11e14a9a8dbd 100644 --- a/src/runtime/rpc/rpc_channel.cc +++ b/src/runtime/rpc/rpc_channel.cc @@ -22,7 +22,7 @@ */ #include "rpc_channel.h" -#include +#include #include diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index a51d98b17f93..b8c08b19c413 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include "../file_utils.h" diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index e800ed231b24..d3ea7b345838 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include diff --git a/src/runtime/static_library.h b/src/runtime/static_library.h index 65ee6f8808c8..0ce4d9e003c6 100644 --- a/src/runtime/static_library.h +++ b/src/runtime/static_library.h @@ -28,7 +28,7 @@ #include #include -#include +#include #include #include diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 63eba5eba23f..ba2b89770bd7 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include "threading_backend.h" #if TVM_THREADPOOL_USE_OPENMP diff --git a/src/runtime/timer.cc b/src/runtime/timer.cc index 075f56337e77..f2adcd353342 100644 --- a/src/runtime/timer.cc +++ b/src/runtime/timer.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h index 5db83ff499e1..ae88843667c3 100644 --- a/src/runtime/vm/attn_backend.h +++ b/src/runtime/vm/attn_backend.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc index 4356305521a3..552f19091bc2 100644 --- a/src/runtime/vm/bytecode.cc +++ b/src/runtime/vm/bytecode.cc @@ -22,7 +22,7 @@ * \brief The bytecode for Relax virtual machine. */ -#include +#include #include #include diff --git a/src/runtime/vm/kv_state.h b/src/runtime/vm/kv_state.h index 4578a8a30690..fd001f8048a2 100644 --- a/src/runtime/vm/kv_state.h +++ b/src/runtime/vm/kv_state.h @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include namespace tvm { diff --git a/src/runtime/vm/lm_support.cc b/src/runtime/vm/lm_support.cc index fc1c84cffaef..d07f84be1647 100644 --- a/src/runtime/vm/lm_support.cc +++ b/src/runtime/vm/lm_support.cc @@ -40,7 +40,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 879066d9cf30..bb3aee7e340b 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/runtime/vulkan/spirv_shader.h b/src/runtime/vulkan/spirv_shader.h index f290d0dbd195..e9575defd110 100644 --- a/src/runtime/vulkan/spirv_shader.h +++ b/src/runtime/vulkan/spirv_shader.h @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index d25817a2d787..826048d8578d 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index 324c3c319cc0..c327149cc2b0 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -20,7 +20,7 @@ #ifndef TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ #define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ -#include +#include #include #include diff --git a/src/runtime/vulkan/vulkan_instance.cc b/src/runtime/vulkan/vulkan_instance.cc index e23e3b7f1ec2..fc88db7644cd 100644 --- a/src/runtime/vulkan/vulkan_instance.cc +++ b/src/runtime/vulkan/vulkan_instance.cc @@ -18,6 +18,7 @@ */ #include "vulkan_instance.h" +#include #include #include diff --git a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index c699ea65136b..27c3ded758ad 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include "../utils.h" diff --git a/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc b/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc index 958bf5f9227f..01d619302a5a 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include diff --git a/src/s_tir/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/s_tir/meta_schedule/schedule_rule/apply_custom_rule.cc index 73cae90cd48e..dfd8f99aee8e 100644 --- a/src/s_tir/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/s_tir/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include "../utils.h" diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc index 69311ea3c8d5..87244c8809e4 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include diff --git a/src/s_tir/meta_schedule/space_generator/space_generator.cc b/src/s_tir/meta_schedule/space_generator/space_generator.cc index 9bddc18d839e..da5f5f399833 100644 --- a/src/s_tir/meta_schedule/space_generator/space_generator.cc +++ b/src/s_tir/meta_schedule/space_generator/space_generator.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include "../../../target/canonicalizer/llvm/arm_aprofile.h" #include "../utils.h" diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h index 664a00cf3268..5dc99d744c28 100644 --- a/src/s_tir/meta_schedule/utils.h +++ b/src/s_tir/meta_schedule/utils.h @@ -43,6 +43,7 @@ #include #include #include +#include #include #include diff --git a/src/s_tir/schedule/concrete_schedule.cc b/src/s_tir/schedule/concrete_schedule.cc index 89bebd33f833..21f5454040a6 100644 --- a/src/s_tir/schedule/concrete_schedule.cc +++ b/src/s_tir/schedule/concrete_schedule.cc @@ -19,6 +19,7 @@ #include "./concrete_schedule.h" #include +#include #include diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc b/src/s_tir/schedule/primitive/blockize_tensorize.cc index 282e167ee55c..da4deb01bc87 100644 --- a/src/s_tir/schedule/primitive/blockize_tensorize.cc +++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc @@ -18,6 +18,7 @@ */ #include +#include #include diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc b/src/s_tir/schedule/primitive/layout_transformation.cc index 4208873b4637..d9c729dd9078 100644 --- a/src/s_tir/schedule/primitive/layout_transformation.cc +++ b/src/s_tir/schedule/primitive/layout_transformation.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include diff --git a/src/s_tir/support/parallel_for.h b/src/s_tir/support/parallel_for.h index 9374027c421a..1b2c5fa18fbb 100644 --- a/src/s_tir/support/parallel_for.h +++ b/src/s_tir/support/parallel_for.h @@ -25,7 +25,7 @@ #define TVM_S_TIR_SUPPORT_PARALLEL_FOR_H_ #include -#include +#include #include #include diff --git a/src/s_tir/support/table_printer.h b/src/s_tir/support/table_printer.h index 6ccaa23eca75..eb29b9706e74 100644 --- a/src/s_tir/support/table_printer.h +++ b/src/s_tir/support/table_printer.h @@ -19,7 +19,7 @@ #ifndef TVM_S_TIR_SUPPORT_TABLE_PRINTER_H_ #define TVM_S_TIR_SUPPORT_TABLE_PRINTER_H_ -#include +#include #include #include diff --git a/src/s_tir/transform/inject_double_buffer.cc b/src/s_tir/transform/inject_double_buffer.cc index b91a5214ae95..9c5e9bf0b8b5 100644 --- a/src/s_tir/transform/inject_double_buffer.cc +++ b/src/s_tir/transform/inject_double_buffer.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include "../../tirx/transform/ir_utils.h" diff --git a/src/s_tir/transform/loop_partition.cc b/src/s_tir/transform/loop_partition.cc index e68b465dd263..d47c861873a7 100644 --- a/src/s_tir/transform/loop_partition.cc +++ b/src/s_tir/transform/loop_partition.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include #include diff --git a/src/s_tir/transform/lower_async_dma.cc b/src/s_tir/transform/lower_async_dma.cc index 628e20be88bc..756461b0dd08 100644 --- a/src/s_tir/transform/lower_async_dma.cc +++ b/src/s_tir/transform/lower_async_dma.cc @@ -32,6 +32,7 @@ #include #include #include +#include #include #include diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 153c42e9d14a..683806768dc2 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "./utils.h" diff --git a/src/support/base64.h b/src/support/base64.h index af011e317538..36cb81cb5447 100644 --- a/src/support/base64.h +++ b/src/support/base64.h @@ -26,7 +26,7 @@ #ifndef TVM_SUPPORT_BASE64_H_ #define TVM_SUPPORT_BASE64_H_ -#include +#include #include #include diff --git a/src/support/pipe.h b/src/support/pipe.h index ec7a8ea14d9e..c208d223bd61 100644 --- a/src/support/pipe.h +++ b/src/support/pipe.h @@ -24,7 +24,7 @@ #ifndef TVM_SUPPORT_PIPE_H_ #define TVM_SUPPORT_PIPE_H_ -#include +#include #include #ifdef _WIN32 diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h index 40d741a762e4..912b4c8d4b46 100644 --- a/src/support/ring_buffer.h +++ b/src/support/ring_buffer.h @@ -24,7 +24,7 @@ #ifndef TVM_SUPPORT_RING_BUFFER_H_ #define TVM_SUPPORT_RING_BUFFER_H_ -#include +#include #include #include diff --git a/src/target/canonicalizer/llvm/arm_aprofile.cc b/src/target/canonicalizer/llvm/arm_aprofile.cc index 97f0071394a9..0ad87ad66bf8 100644 --- a/src/target/canonicalizer/llvm/arm_aprofile.cc +++ b/src/target/canonicalizer/llvm/arm_aprofile.cc @@ -24,6 +24,8 @@ #include "arm_aprofile.h" +#include + #include #include diff --git a/src/target/cuda/codegen_cuda.cc b/src/target/cuda/codegen_cuda.cc index d76e5fbac187..ec5f014e8e0b 100644 --- a/src/target/cuda/codegen_cuda.cc +++ b/src/target/cuda/codegen_cuda.cc @@ -31,6 +31,7 @@ #include #include +#include #include #include #include diff --git a/src/target/cuda/ptx.h b/src/target/cuda/ptx.h index b82a9c6ad3f3..7bdc16e3ae0c 100644 --- a/src/target/cuda/ptx.h +++ b/src/target/cuda/ptx.h @@ -24,7 +24,7 @@ #ifndef TVM_TARGET_SOURCE_PTX_H_ #define TVM_TARGET_SOURCE_PTX_H_ -#include +#include #include #include diff --git a/src/target/hexagon/llvm/codegen_hexagon.cc b/src/target/hexagon/llvm/codegen_hexagon.cc index a7dbda398d8f..c83af58c4ce7 100644 --- a/src/target/hexagon/llvm/codegen_hexagon.cc +++ b/src/target/hexagon/llvm/codegen_hexagon.cc @@ -45,6 +45,7 @@ #include #include #include +#include #include #include diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 31e8b6a83290..9e1a8ce068cc 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -26,6 +26,7 @@ #include #include #include +#include namespace tvm { namespace codegen { diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc index 7c328f18ab12..18da2e66d7a8 100644 --- a/src/target/llvm/codegen_aarch64.cc +++ b/src/target/llvm/codegen_aarch64.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include "../../arith/scalable_expression.h" #include "codegen_cpu.h" diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 974cbb1e8a5c..09308a6ebbfd 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -51,6 +51,7 @@ #include #include #include +#include #include #include diff --git a/src/target/metal/codegen_metal.cc b/src/target/metal/codegen_metal.cc index e0cc10fe3e5c..c84df824a14f 100644 --- a/src/target/metal/codegen_metal.cc +++ b/src/target/metal/codegen_metal.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include diff --git a/src/target/rocm/llvm/codegen_amdgpu.cc b/src/target/rocm/llvm/codegen_amdgpu.cc index ef31805be7d2..2da399231e31 100644 --- a/src/target/rocm/llvm/codegen_amdgpu.cc +++ b/src/target/rocm/llvm/codegen_amdgpu.cc @@ -47,6 +47,7 @@ #include #include #include +#include #include "../../../runtime/metadata.h" #include "../../build_common.h" diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 8d4cd50a9a6c..e593852e43ad 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include diff --git a/src/tirx/transform/lower_intrin.cc b/src/tirx/transform/lower_intrin.cc index 772bdd4d5b2a..a8c60d33b9d2 100644 --- a/src/tirx/transform/lower_intrin.cc +++ b/src/tirx/transform/lower_intrin.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include diff --git a/src/tirx/transform/lower_tvm_builtin.cc b/src/tirx/transform/lower_tvm_builtin.cc index 8e71d2f26f40..3ba72294bb2c 100644 --- a/src/tirx/transform/lower_tvm_builtin.cc +++ b/src/tirx/transform/lower_tvm_builtin.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include diff --git a/src/tirx/transform/storage_rewrite.cc b/src/tirx/transform/storage_rewrite.cc index 99d62ab02b20..f64d262f97c8 100644 --- a/src/tirx/transform/storage_rewrite.cc +++ b/src/tirx/transform/storage_rewrite.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include #include diff --git a/src/tirx/transform/tvm_ffi_binder.cc b/src/tirx/transform/tvm_ffi_binder.cc index 4a0c2d3e124c..881d94ba61ab 100644 --- a/src/tirx/transform/tvm_ffi_binder.cc +++ b/src/tirx/transform/tvm_ffi_binder.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "ir_utils.h" @@ -107,7 +108,7 @@ void TVMFFIABIBuilder::EmitTypeIndexCheck(int param_index, const PrimExpr& cond, // RenderAccessPath // ============================================================ -ffi::String TVMFFIABIBuilder::RenderAccessPath(const AccessPath& path) const { +ffi::String TVMFFIABIBuilder::RenderAccessPath(const ffi::reflection::AccessPath& path) const { ffi::Array steps = path->ToSteps(); std::ostringstream os; bool first_printed = false; @@ -148,7 +149,7 @@ ffi::String TVMFFIABIBuilder::RenderAccessPath(const AccessPath& path) const { // GetParamIndex // ============================================================ -int TVMFFIABIBuilder::GetParamIndex(const AccessPath& path) const { +int TVMFFIABIBuilder::GetParamIndex(const ffi::reflection::AccessPath& path) const { ffi::Array steps = path->ToSteps(); if (steps.size() >= 1 && steps[0]->kind == ffi::reflection::AccessKind::kArrayItem) { return static_cast(steps[0]->key.cast()); @@ -157,11 +158,11 @@ int TVMFFIABIBuilder::GetParamIndex(const AccessPath& path) const { } // ============================================================ -// BindScalar (scalar bind with AccessPath) +// BindScalar (scalar bind with ffi::reflection::AccessPath) // ============================================================ bool TVMFFIABIBuilder::BindScalar(const PrimExpr& arg, const PrimExpr& value, - const AccessPath& path, bool with_lets) { + const ffi::reflection::AccessPath& path, bool with_lets) { TVM_FFI_ICHECK_EQ(arg.dtype(), value.dtype()); if (arg.as()) { Var v_arg = Downcast(arg); @@ -225,7 +226,7 @@ bool TVMFFIABIBuilder::BindScalar(const PrimExpr& arg, const PrimExpr& value, // ============================================================ /*! - * \brief Render PrimExpr to string with variable names replaced by AccessPath names. + * \brief Render PrimExpr to string with variable names replaced by ffi::reflection::AccessPath names. * * Uses ExprFunctor for generic dispatch over all expression types. * The default TIR printer sanitizes Var name_hints (e.g. "B.shape[0]" -> "B_shape_0_") @@ -324,24 +325,24 @@ TVMFFIABIBuilder::Result TVMFFIABIBuilder::Finalize() { } // ============================================================ -// BindArray (array bind with AccessPath) +// BindArray (array bind with ffi::reflection::AccessPath) // ============================================================ void TVMFFIABIBuilder::BindArray(const ffi::Array& arg, const ffi::Array& value, - const AccessPath& base_path) { + const ffi::reflection::AccessPath& base_path) { TVM_FFI_ICHECK_EQ(arg.size(), value.size()) << "Array size mismatch at " << RenderAccessPath(base_path); for (size_t i = 0; i < arg.size(); ++i) { - AccessPath elem_path = base_path->ArrayItem(i); + ffi::reflection::AccessPath elem_path = base_path->ArrayItem(i); BindScalar(arg[i], value[i], elem_path, false); } } // ============================================================ -// BindBuffer (buffer-to-buffer bind with AccessPath) +// BindBuffer (buffer-to-buffer bind with ffi::reflection::AccessPath) // ============================================================ -void TVMFFIABIBuilder::BindBuffer(const Buffer& arg, const Buffer& value, AccessPath base_path, +void TVMFFIABIBuilder::BindBuffer(const Buffer& arg, const Buffer& value, ffi::reflection::AccessPath base_path, bool fuzzy_match) { TVM_FFI_ICHECK_EQ(arg.scope(), value.scope()) << "Argument " << arg->name << " Buffer bind scope mismatch"; @@ -360,9 +361,9 @@ void TVMFFIABIBuilder::BindBuffer(const Buffer& arg, const Buffer& value, Access << " required elem_offset=" << arg->elem_offset << ", provided elem_offset=" << value->elem_offset; } - AccessPath data_path = base_path->Attr(ffi::String("data")); + ffi::reflection::AccessPath data_path = base_path->Attr(ffi::String("data")); BindScalar(arg->data, value->data, data_path, false); - AccessPath offset_path = base_path->Attr(ffi::String("elem_offset")); + ffi::reflection::AccessPath offset_path = base_path->Attr(ffi::String("elem_offset")); if (BindScalar(arg->elem_offset, value->elem_offset, offset_path, false)) { if (arg->offset_factor > 1) { PrimExpr offset = value->elem_offset; @@ -385,8 +386,8 @@ void TVMFFIABIBuilder::BindBuffer(const Buffer& arg, const Buffer& value, Access } } - AccessPath shape_path = base_path->Attr(ffi::String("shape")); - AccessPath strides_path = base_path->Attr(ffi::String("strides")); + ffi::reflection::AccessPath shape_path = base_path->Attr(ffi::String("shape")); + ffi::reflection::AccessPath strides_path = base_path->Attr(ffi::String("strides")); if (arg->shape.size() < value->shape.size()) { TVM_FFI_ICHECK(fuzzy_match) << "Buffer size mismatch at " << RenderAccessPath(base_path); @@ -397,14 +398,14 @@ void TVMFFIABIBuilder::BindBuffer(const Buffer& arg, const Buffer& value, Access << " vs " << value->shape; } for (size_t i = 0; i < arg->shape.size(); ++i) { - AccessPath shape_k_path = shape_path->ArrayItem(i); + ffi::reflection::AccessPath shape_k_path = shape_path->ArrayItem(i); BindScalar(arg->shape[i], value->shape[i + diff], shape_k_path, false); } if (value->strides.size() != 0) { TVM_FFI_ICHECK_EQ(arg->strides.size(), arg->shape.size()); TVM_FFI_ICHECK_EQ(value->strides.size(), value->shape.size()); for (size_t i = 0; i < arg->strides.size(); ++i) { - AccessPath strides_k_path = strides_path->ArrayItem(i); + ffi::reflection::AccessPath strides_k_path = strides_path->ArrayItem(i); BindScalar(arg->strides[i], value->strides[i + diff], strides_k_path, false); } } @@ -513,7 +514,7 @@ void TVMFFIABIBuilder::DecodeParam(int param_index) { } // Bind scalar param to loaded value (defines vars before buffer binds reference them) - AccessPath param_path = AccessPath::Root()->Extend(AccessStep::ArrayItem(param_index)); + ffi::reflection::AccessPath param_path = ffi::reflection::AccessPath::Root()->Extend(AccessStep::ArrayItem(param_index)); BindScalar(param, arg_value, param_path, true); } @@ -535,8 +536,8 @@ void TVMFFIABIBuilder::DecodeAllParams() { Var param = params_[i]; if (buffer_map_.count(param)) { Buffer buffer = buffer_map_[param]; - AccessPath param_path = - AccessPath::Root()->Extend(AccessStep::ArrayItem(i))->Attr(ffi::String(buffer->name)); + ffi::reflection::AccessPath param_path = + ffi::reflection::AccessPath::Root()->Extend(AccessStep::ArrayItem(i))->Attr(ffi::String(buffer->name)); DecodeParamDLTensor(buffer, device_type_, device_id_, param, func_name_ + "." + param->name_hint, param_path); decl_buffers_.push_back(DeclBuffer(buffer)); @@ -571,7 +572,7 @@ PrimExpr TVMFFIABIBuilder::LoadInt64ArrayElem(const Var& ptr, int index) { void TVMFFIABIBuilder::BindCompactStrides(const Buffer& buffer, const Var& strides_ptr, const PrimExpr& v_strides_is_null, - const AccessPath& param_path) { + const ffi::reflection::AccessPath& param_path) { DataType stype = buffer->DefaultIndexType(); PrimExpr expect_stride = make_const(stype, 1); ffi::Array conds; @@ -598,7 +599,7 @@ void TVMFFIABIBuilder::BindCompactStrides(const Buffer& buffer, const Var& strid void TVMFFIABIBuilder::BindAutoBroadcastStrides(const Buffer& buffer, const Var& strides_ptr, const PrimExpr& v_strides_is_null, - const AccessPath& param_path) { + const ffi::reflection::AccessPath& param_path) { DataType stype = buffer->DefaultIndexType(); PrimExpr stride = make_const(stype, 1); for (size_t i = buffer->shape.size(); i != 0; --i) { @@ -606,7 +607,7 @@ void TVMFFIABIBuilder::BindAutoBroadcastStrides(const Buffer& buffer, const Var& PrimExpr value = cast(buffer->shape[k].dtype(), LoadInt64ArrayElem(strides_ptr, k)); value = tvm::if_then_else(v_strides_is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); - AccessPath strides_k_path = param_path->Attr(ffi::String("strides"))->ArrayItem(k); + ffi::reflection::AccessPath strides_k_path = param_path->Attr(ffi::String("strides"))->ArrayItem(k); BindScalar(buffer->strides[k], value, strides_k_path, true); stride = analyzer_.Simplify(stride * buffer->shape[k]); } @@ -614,11 +615,11 @@ void TVMFFIABIBuilder::BindAutoBroadcastStrides(const Buffer& buffer, const Var& void TVMFFIABIBuilder::BindRegularStrides(const Buffer& buffer, const Var& strides_ptr, const Var& shape_ptr, const PrimExpr& v_strides_is_null, - const AccessPath& param_path) { + const ffi::reflection::AccessPath& param_path) { PrimExpr stride_from_shape = 1; for (int k = buffer->strides.size() - 1; k >= 0; k--) { PrimExpr explicit_stride = cast(buffer->shape[k].dtype(), LoadInt64ArrayElem(strides_ptr, k)); - AccessPath strides_k_path = param_path->Attr(ffi::String("strides"))->ArrayItem(k); + ffi::reflection::AccessPath strides_k_path = param_path->Attr(ffi::String("strides"))->ArrayItem(k); BindScalar(buffer->strides[k], tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride), strides_k_path, true); @@ -632,11 +633,11 @@ void TVMFFIABIBuilder::BindRegularStrides(const Buffer& buffer, const Var& strid void TVMFFIABIBuilder::DecodeParamDLTensor(const Buffer& buffer, const PrimExpr& device_type, const PrimExpr& device_id, const Var& handle, - const std::string& arg_name, AccessPath base_path) { + const std::string& arg_name, ffi::reflection::AccessPath base_path) { const DataType tvm_ndim_type = DataType::Int(32); std::string buf_name = buffer->name; - AccessPath param_path = base_path; + ffi::reflection::AccessPath param_path = base_path; int param_index = GetParamIndex(base_path); // ── Section: Null pointer check ────────────────────────────── @@ -675,7 +676,7 @@ void TVMFFIABIBuilder::DecodeParamDLTensor(const Buffer& buffer, const PrimExpr& buffer->dtype == DataType::Int(1)) { break; } - AccessPath shape_k_path = param_path->Attr(ffi::String("shape"))->ArrayItem(k); + ffi::reflection::AccessPath shape_k_path = param_path->Attr(ffi::String("shape"))->ArrayItem(k); BindScalar(buffer->shape[k], cast(buffer->shape[k].dtype(), LoadInt64ArrayElem(shape_ptr, k)), shape_k_path, true); } @@ -693,7 +694,7 @@ void TVMFFIABIBuilder::DecodeParamDLTensor(const Buffer& buffer, const PrimExpr& // ── Section: byte_offset ───────────────────────────────────── int data_bytes = GetVectorBytes(buffer->dtype); - AccessPath byte_offset_path = param_path->Attr(ffi::String("byte_offset")); + ffi::reflection::AccessPath byte_offset_path = param_path->Attr(ffi::String("byte_offset")); if (const auto* const_offset = buffer->elem_offset.as()) { BindScalar(make_const(DataType::UInt(64), const_offset->value * data_bytes), TVMStructGet(DataType::UInt(64), handle, 0, builtin::kDLTensorByteOffset), @@ -739,17 +740,17 @@ void TVMFFIABIBuilder::DecodeParamDLTensor(const Buffer& buffer, const PrimExpr& device_name); } } else { - AccessPath device_type_path = param_path->Attr(ffi::String("device_type")); + ffi::reflection::AccessPath device_type_path = param_path->Attr(ffi::String("device_type")); BindScalar(device_type_, actual_device_type, device_type_path, true); } - AccessPath device_id_path = param_path->Attr(ffi::String("device_id")); + ffi::reflection::AccessPath device_id_path = param_path->Attr(ffi::String("device_id")); BindScalar(device_id_, TVMStructGet(DataType::Int(32), handle, 0, builtin::kDLTensorDeviceId), device_id_path, true); } // ── Section: data pointer ──────────────────────────────────── { - AccessPath data_path = param_path->Attr(ffi::String("data")); + ffi::reflection::AccessPath data_path = param_path->Attr(ffi::String("data")); if (BindScalar(buffer->data, TVMStructGet(DataType::Handle(), handle, 0, builtin::kDLTensorData), data_path, true)) { diff --git a/src/tirx/transform/tvm_ffi_binder.h b/src/tirx/transform/tvm_ffi_binder.h index 5f17d970dd8f..03ed0b77fede 100644 --- a/src/tirx/transform/tvm_ffi_binder.h +++ b/src/tirx/transform/tvm_ffi_binder.h @@ -49,7 +49,7 @@ namespace tirx { * generation for packed function parameters. The primary public method is * DecodeAllParams(), which handles everything: type index extraction, * type checking (TypeError), value loading, scalar binding, buffer - * binding, and rich error message generation with AccessPath. + * binding, and rich error message generation with ffi::reflection::AccessPath. * * ## Generated statement ordering * @@ -85,7 +85,7 @@ namespace tirx { */ class TVMFFIABIBuilder { public: - /*! \brief Variable definition info: bound value and the AccessPath where first defined. */ + /*! \brief Variable definition info: bound value and the ffi::reflection::AccessPath where first defined. */ struct VarDefInfo { PrimExpr value; ffi::reflection::AccessPath first_def_path; @@ -221,10 +221,10 @@ class TVMFFIABIBuilder { */ PrimExpr DecodeParamFloat(int param_index, const Var& type_index, DataType dtype); - // ── Private binding submethods (all take AccessPath) ─────────── + // ── Private binding submethods (all take ffi::reflection::AccessPath) ─────────── /*! - * \brief Internal scalar bind with AccessPath tracking and rich error messages. + * \brief Internal scalar bind with ffi::reflection::AccessPath tracking and rich error messages. * * Binds \p arg to \p value. If arg is a Var not yet in var_defs_, creates a * new definition (Bind to init_nest_); otherwise emits a rich assertion @@ -232,36 +232,36 @@ class TVMFFIABIBuilder { * * When arg is a non-Var expression (e.g. batch_size + 1), the assertion is * deferred to Finalize() so display-var substitution can render the expression - * using AccessPath names (e.g. "k.shape[0] + 1" instead of "batch_size + 1"). + * using ffi::reflection::AccessPath names (e.g. "k.shape[0] + 1" instead of "batch_size + 1"). * * \param arg The argument expression to bind (typically a Var or constant). * \param value The value expression to bind to the argument. * \param with_lets If true, emit Bind bindings into init_nest_. - * \param path AccessPath for rich error message rendering. + * \param path ffi::reflection::AccessPath for rich error message rendering. * \return True if this was the first bind (definition created), false otherwise. */ bool BindScalar(const PrimExpr& arg, const PrimExpr& value, const ffi::reflection::AccessPath& path, bool with_lets); /*! - * \brief Array bind: binds element-wise with AccessPath[k] for each element. + * \brief Array bind: binds element-wise with ffi::reflection::AccessPath[k] for each element. * * \param arg The expected array of expressions. * \param value The actual array of expressions to bind against. - * \param base_path Base AccessPath; each element appends ArrayItem(k). + * \param base_path Base ffi::reflection::AccessPath; each element appends ArrayItem(k). */ void BindArray(const ffi::Array& arg, const ffi::Array& value, const ffi::reflection::AccessPath& base_path); /*! - * \brief Buffer-to-buffer bind with AccessPath. + * \brief Buffer-to-buffer bind with ffi::reflection::AccessPath. * * Binds data, elem_offset, shape, and strides of \p arg against \p value, * emitting assertions for any mismatches. * * \param arg The expected buffer definition. * \param value The actual buffer to bind against. - * \param base_path Base AccessPath for the buffer parameter. + * \param base_path Base ffi::reflection::AccessPath for the buffer parameter. * \param fuzzy_match If true, allow value to have more dimensions than arg. */ void BindBuffer(const Buffer& arg, const Buffer& value, ffi::reflection::AccessPath base_path, @@ -275,7 +275,7 @@ class TVMFFIABIBuilder { * \param device_id The expected device id expression. * \param handle The variable holding the DLTensor handle. * \param arg_name Human-readable name for error messages. - * \param base_path Base AccessPath for the tensor parameter. + * \param base_path Base ffi::reflection::AccessPath for the tensor parameter. */ void DecodeParamDLTensor(const Buffer& buffer, const PrimExpr& device_type, const PrimExpr& device_id, const Var& handle, @@ -310,7 +310,7 @@ class TVMFFIABIBuilder { * \param buffer The expected buffer definition. * \param strides_ptr The strides pointer variable. * \param v_strides_is_null Expression checking if strides pointer is NULL. - * \param param_path AccessPath for the tensor parameter. + * \param param_path ffi::reflection::AccessPath for the tensor parameter. */ void BindCompactStrides(const Buffer& buffer, const Var& strides_ptr, const PrimExpr& v_strides_is_null, @@ -322,7 +322,7 @@ class TVMFFIABIBuilder { * \param buffer The expected buffer definition. * \param strides_ptr The strides pointer variable. * \param v_strides_is_null Expression checking if strides pointer is NULL. - * \param param_path AccessPath for the tensor parameter. + * \param param_path ffi::reflection::AccessPath for the tensor parameter. */ void BindAutoBroadcastStrides(const Buffer& buffer, const Var& strides_ptr, const PrimExpr& v_strides_is_null, @@ -335,7 +335,7 @@ class TVMFFIABIBuilder { * \param strides_ptr The strides pointer variable. * \param shape_ptr The shape pointer variable (for computing C-contiguous strides). * \param v_strides_is_null Expression checking if strides pointer is NULL. - * \param param_path AccessPath for the tensor parameter. + * \param param_path ffi::reflection::AccessPath for the tensor parameter. */ void BindRegularStrides(const Buffer& buffer, const Var& strides_ptr, const Var& shape_ptr, const PrimExpr& v_strides_is_null, @@ -356,15 +356,15 @@ class TVMFFIABIBuilder { void EmitTypeIndexCheck(int param_index, const PrimExpr& cond, const std::string& expected_type); /*! - * \brief Render an AccessPath as a human-readable string (e.g. "a.shape[0]"). - * \param path The AccessPath to render. + * \brief Render an ffi::reflection::AccessPath as a human-readable string (e.g. "a.shape[0]"). + * \param path The ffi::reflection::AccessPath to render. * \return A human-readable string representation of the path. */ ffi::String RenderAccessPath(const ffi::reflection::AccessPath& path) const; /*! * \brief Extract param_index from the root ArrayItem step of a path. - * \param path The AccessPath to extract the index from. + * \param path The ffi::reflection::AccessPath to extract the index from. * \return The param index, or -1 if not found. */ int GetParamIndex(const ffi::reflection::AccessPath& path) const; @@ -373,7 +373,7 @@ class TVMFFIABIBuilder { * \brief Render pending constant-expression assertions with display-var substitution. * * For each pending assertion, substitutes known variable names with their - * AccessPath-rendered names (e.g. batch_size → "k.shape[0]") so error messages + * ffi::reflection::AccessPath-rendered names (e.g. batch_size → "k.shape[0]") so error messages * show human-readable expressions like "k.shape[0] + 1" instead of "batch_size + 1". */ void RenderPendingAsserts(); @@ -416,7 +416,7 @@ class TVMFFIABIBuilder { PrimExpr device_type_; /*! \brief The device id variable. */ PrimExpr device_id_; - /*! \brief Map from param_index to param_name for AccessPath rendering. */ + /*! \brief Map from param_index to param_name for ffi::reflection::AccessPath rendering. */ std::unordered_map param_names_; // Pre-cached common message fragments for string sharing across assertions diff --git a/src/tirx/transform/vectorize_loop.cc b/src/tirx/transform/vectorize_loop.cc index 104c111c722b..45d6a5e118be 100644 --- a/src/tirx/transform/vectorize_loop.cc +++ b/src/tirx/transform/vectorize_loop.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include #include From 9e3398d4781a84d3b02436360ee8ec4eba1e7fea Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 30 Apr 2026 13:47:24 +0000 Subject: [PATCH 3/4] [REFACTOR][FFI] Fix cpptest unqualified references to ffi types Commit 1 of this PR removed `using ffi::IsContiguous;`, `using ffi::DLDataTypeToString;`, and similar free-function pull-ins from public headers. The cpptest sources in `tests/cpp/` still referenced those names unqualified through `tvm::runtime::*`, which broke the cpu/wasm CI builds. Qualify all such references in `tests/cpp/` with the canonical `ffi::*` namespace. Files: tests/cpp/ndarray_test.cc, tests/cpp/tir_scalable_datatype.cc --- tests/cpp/ndarray_test.cc | 8 ++++---- tests/cpp/tir_scalable_datatype.cc | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/cpp/ndarray_test.cc b/tests/cpp/ndarray_test.cc index fdb064b4a46b..c02efecc5148 100644 --- a/tests/cpp/ndarray_test.cc +++ b/tests/cpp/ndarray_test.cc @@ -30,7 +30,7 @@ TEST(TensorTest, IsContiguous_ContiguousStride) { int64_t strides[] = {10, 1}; managed_tensor->dl_tensor.strides = strides; - TVM_FFI_ICHECK(runtime::IsContiguous(managed_tensor->dl_tensor)); + TVM_FFI_ICHECK(ffi::IsContiguous(managed_tensor->dl_tensor)); managed_tensor->deleter(managed_tensor); } @@ -41,7 +41,7 @@ TEST(TensorTest, IsContiguous_NullStride) { managed_tensor->dl_tensor.strides = nullptr; - TVM_FFI_ICHECK(runtime::IsContiguous(managed_tensor->dl_tensor)); + TVM_FFI_ICHECK(ffi::IsContiguous(managed_tensor->dl_tensor)); managed_tensor->deleter(managed_tensor); } @@ -53,7 +53,7 @@ TEST(TensorTest, IsContiguous_AnyStrideForSingular) { int64_t strides[] = {10, 1, 1}; // strides[1] is normalized to 1 because shape[1] == 1. managed_tensor->dl_tensor.strides = strides; - TVM_FFI_ICHECK(runtime::IsContiguous(managed_tensor->dl_tensor)); + TVM_FFI_ICHECK(ffi::IsContiguous(managed_tensor->dl_tensor)); managed_tensor->dl_tensor.strides = nullptr; managed_tensor->deleter(managed_tensor); @@ -66,7 +66,7 @@ TEST(TensorTest, IsContiguous_UncontiguousStride) { int64_t strides[] = {1, 1, 1}; managed_tensor->dl_tensor.strides = strides; - TVM_FFI_ICHECK(!runtime::IsContiguous(managed_tensor->dl_tensor)); + TVM_FFI_ICHECK(!ffi::IsContiguous(managed_tensor->dl_tensor)); managed_tensor->dl_tensor.strides = nullptr; managed_tensor->deleter(managed_tensor); diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc index 9be9e8552e83..fd9f76eee366 100644 --- a/tests/cpp/tir_scalable_datatype.cc +++ b/tests/cpp/tir_scalable_datatype.cc @@ -83,7 +83,7 @@ TEST(ScalableDataType, TestIsScalar) { TEST(ScalableDataType, TestScalableDataTypeToString) { tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); - EXPECT_EQ(tvm::runtime::DLDataTypeToString(scalable_type), "int32xvscalex4"); + EXPECT_EQ(tvm::ffi::DLDataTypeToString(scalable_type), "int32xvscalex4"); } TEST(ScalableDataType, TestStringToScalableDataType) { From 97a923242471a1afc75ec4ae3414fe509d4bb60a Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 30 Apr 2026 14:15:55 +0000 Subject: [PATCH 4/4] [REFACTOR][FFI] Re-keep semantically-meaningful type aliases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After the previous commit deleted all ffi:: indirection aliases, review identified 9 aliases that carry distinct semantic meaning beyond a mere namespace shortcut and should be preserved: Region = ffi::Array (tirx::, arith::) MemoryScope = ffi::String (global_info.h, virtual_device.h) TGlobalSymbol = ffi::String (tirx/op_attr_types.h) TScriptPrinterName = ffi::String (tirx/op_attr_types.h) FCallPacked = ffi::String (relax/op_attr_types.h) AccessPath = ffi::reflection::AccessPath (script/printer namespace only) Re-add these aliases in the relevant headers and revert all call-site rewrites (template arguments, variable types, return types) to use the alias names again. String key literals (e.g., "FCallPacked") are left unchanged — they are registry lookup keys, not type names. Files outside the printer namespace (src/s_tir/analysis/is_pure_function.cc, src/ir/script_printer.cc) keep the fully-qualified ffi::reflection::AccessPath. --- include/tvm/arith/bound.h | 5 +- include/tvm/ir/global_info.h | 9 +- include/tvm/relax/attrs/op.h | 2 +- include/tvm/relax/op_attr_types.h | 6 + include/tvm/script/printer/doc.h | 22 +-- include/tvm/script/printer/ir_docsifier.h | 14 +- include/tvm/target/virtual_device.h | 21 ++- include/tvm/tirx/op.h | 2 +- include/tvm/tirx/op_attr_types.h | 10 ++ include/tvm/tirx/var.h | 2 +- src/arith/domain_touched.cc | 8 +- src/ir/global_info.cc | 4 +- src/relax/backend/vm/codegen_vm.cc | 2 +- src/relax/op/op.cc | 6 +- src/relax/op/tensor/set.cc | 4 +- src/relax/script/printer/binding.cc | 14 +- src/relax/script/printer/call.cc | 29 ++- src/relax/script/printer/distributed.cc | 13 +- src/relax/script/printer/expr.cc | 23 ++- src/relax/script/printer/function.cc | 165 +++++++++--------- src/relax/script/printer/region.cc | 24 ++- src/relax/script/printer/struct_info.cc | 23 ++- src/relax/script/printer/tir.cc | 29 ++- src/relax/script/printer/type.cc | 16 +- src/relax/script/printer/utils.h | 8 +- src/relax/transform/legalize_ops.cc | 2 +- src/s_tir/schedule/primitive/cache_index.cc | 2 +- .../schedule/primitive/cache_read_write.cc | 30 ++-- .../schedule/primitive/decompose_padding.cc | 2 +- .../schedule/primitive/rolling_buffer.cc | 4 +- src/s_tir/support/nd_int_set.h | 2 +- src/s_tir/transform/compact_buffer_region.cc | 16 +- .../transform/inject_software_pipeline.cc | 6 +- src/s_tir/transform/lower_match_buffer.cc | 4 +- .../transform/memhammer_lower_auto_copy.cc | 2 +- src/script/printer/doc.cc | 6 +- .../printer/doc_printer/base_doc_printer.cc | 8 +- .../printer/doc_printer/base_doc_printer.h | 4 +- src/script/printer/ir/distributed.cc | 2 +- src/script/printer/ir/ir.cc | 16 +- src/script/printer/ir/misc.cc | 4 +- src/script/printer/ir_docsifier.cc | 4 +- src/script/printer/utils.h | 4 +- src/target/cuda/intrin_rule_cuda.cc | 8 +- src/target/llvm/codegen_llvm.h | 2 +- src/target/metal/intrin_rule_metal.cc | 6 +- src/target/source/codegen_c.h | 2 +- src/target/virtual_device.cc | 8 +- src/target/webgpu/intrin_rule_webgpu.cc | 6 +- src/tirx/op/builtin.cc | 10 +- src/tirx/op/op.cc | 6 +- src/tirx/op/runtime.cc | 4 +- src/tirx/script/printer/block.cc | 20 +-- src/tirx/script/printer/buffer.cc | 40 ++--- src/tirx/script/printer/expr.cc | 52 +++--- src/tirx/script/printer/for_loop.cc | 2 +- src/tirx/script/printer/function.cc | 18 +- src/tirx/script/printer/ir.cc | 14 +- src/tirx/script/printer/stmt.cc | 34 ++-- src/tirx/script/printer/utils.h | 8 +- src/tirx/transform/ir_utils.cc | 4 +- src/tirx/transform/ir_utils.h | 2 +- 62 files changed, 417 insertions(+), 408 deletions(-) diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index a286103bed8c..7ae36fb289f0 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -33,6 +33,7 @@ namespace tvm { namespace arith { +using tirx::Region; using tirx::Stmt; using tirx::Var; using tirx::VarNode; @@ -76,8 +77,8 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond, * \param consider_stores If stores are considered. * \return The domain that covers all the calls or provides within the given statement. */ -ffi::Array DomainTouched(const Stmt& body, const tirx::Buffer& buffer, bool consider_loads, - bool consider_stores); +Region DomainTouched(const Stmt& body, const tirx::Buffer& buffer, bool consider_loads, + bool consider_stores); } // namespace arith } // namespace tvm diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index 1fb87654d0c1..3533b7868732 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -31,6 +31,11 @@ namespace tvm { +/*! + * \brief Abstract label for an area of memory. + */ +using MemoryScope = ffi::String; + /*! * \brief GlobalInfo are globally static object that are referred by the IR itself. * Base node for all global info that can appear in the IR @@ -62,7 +67,7 @@ class VDeviceNode : public GlobalInfoNode { * differentiate between distinct devices with same Target, such as multiple GPUs. */ int vdevice_id; - ffi::String memory_scope; + MemoryScope memory_scope; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -81,7 +86,7 @@ class VDeviceNode : public GlobalInfoNode { */ class VDevice : public GlobalInfo { public: - TVM_DLL explicit VDevice(Target tgt, int dev_id, ffi::String mem_scope); + TVM_DLL explicit VDevice(Target tgt, int dev_id, MemoryScope mem_scope); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VDevice, GlobalInfo, VDeviceNode); }; diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index e7dc64f8005e..54640901ff53 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -104,7 +104,7 @@ struct ToVDeviceAttrs : public AttrsNodeReflAdapter { struct HintOnDeviceAttrs : public AttrsNodeReflAdapter { int32_t device_type; int32_t index; - ffi::String memory_scope; + MemoryScope memory_scope; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index 85d0c333c8ae..1fd9b45c323c 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -53,6 +53,12 @@ enum OpPatternKind { kOpaque = 8 }; +/*! + * \brief Packed function implementation for operators. The relax operator will be lowered to + * this packed function call during codegen. + */ +using FCallPacked = ffi::String; + /*! * \brief Infer output struct info given the call * diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index e3d32cb50335..8803e846c08f 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -31,6 +31,8 @@ namespace tvm { namespace script { namespace printer { +using AccessPath = ffi::reflection::AccessPath; + // Forward declaration class Doc; @@ -61,7 +63,7 @@ class DocNode : public ffi::Object { * this Doc is generated, in order to position the diagnostic * message. */ - mutable ffi::Array source_paths; + mutable ffi::Array source_paths; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -258,15 +260,14 @@ class LiteralDocNode : public ExprDocNode { */ class LiteralDoc : public ExprDoc { protected: - explicit LiteralDoc(ffi::Any value, - const ffi::Optional& object_path); + explicit LiteralDoc(ffi::Any value, const ffi::Optional& object_path); public: /*! * \brief Create a LiteralDoc to represent None/null/empty value. * \param p The object path */ - static LiteralDoc None(const ffi::Optional& p) { + static LiteralDoc None(const ffi::Optional& p) { return LiteralDoc(ffi::Any(nullptr), p); } /*! @@ -274,7 +275,7 @@ class LiteralDoc : public ExprDoc { * \param v The integer value. * \param p The object path */ - static LiteralDoc Int(int64_t v, const ffi::Optional& p) { + static LiteralDoc Int(int64_t v, const ffi::Optional& p) { return LiteralDoc(IntImm(DataType::Int(64), v), p); } /*! @@ -282,7 +283,7 @@ class LiteralDoc : public ExprDoc { * \param v The boolean value. * \param p The object path */ - static LiteralDoc Boolean(bool v, const ffi::Optional& p) { + static LiteralDoc Boolean(bool v, const ffi::Optional& p) { return LiteralDoc(IntImm(DataType::Bool(), v), p); } /*! @@ -290,7 +291,7 @@ class LiteralDoc : public ExprDoc { * \param v The float value. * \param p The object path */ - static LiteralDoc Float(double v, const ffi::Optional& p) { + static LiteralDoc Float(double v, const ffi::Optional& p) { return LiteralDoc(FloatImm(DataType::Float(64), v), p); } /*! @@ -298,7 +299,7 @@ class LiteralDoc : public ExprDoc { * \param v The string value. * \param p The object path */ - static LiteralDoc Str(const ffi::String& v, const ffi::Optional& p) { + static LiteralDoc Str(const ffi::String& v, const ffi::Optional& p) { return LiteralDoc(v, p); } /*! @@ -306,8 +307,7 @@ class LiteralDoc : public ExprDoc { * \param v The string value. * \param p The object path */ - static LiteralDoc DataType(const runtime::DataType& v, - const ffi::Optional& p) { + static LiteralDoc DataType(const runtime::DataType& v, const ffi::Optional& p) { std::string dtype = v.is_void() ? "void" : ffi::DLDataTypeToString(v); return LiteralDoc::Str(dtype, p); } @@ -316,7 +316,7 @@ class LiteralDoc : public ExprDoc { * \param v The device. * \param p The object path */ - static LiteralDoc Device(const DLDevice& v, const ffi::Optional& p) { + static LiteralDoc Device(const DLDevice& v, const ffi::Optional& p) { std::ostringstream os; runtime::operator<<(os, v); return LiteralDoc::Str(os.str(), p); diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 7c0e082fbc5b..e49d4f8a1cc0 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -36,6 +36,8 @@ namespace tvm { namespace script { namespace printer { +using AccessPath = ffi::reflection::AccessPath; + //////////////////////// Frame //////////////////////// class IRDocsifier; @@ -235,7 +237,7 @@ class IRDocsifierNode : public ffi::Object { * \return The Doc object. */ template - inline TDoc AsDoc(const Any& obj, const ffi::reflection::AccessPath& path) const; + inline TDoc AsDoc(const Any& obj, const AccessPath& path) const; }; /*! @@ -243,7 +245,7 @@ class IRDocsifierNode : public ffi::Object { */ class IRDocsifier : public ffi::ObjectRef { public: - using FType = IRDocsifierFunctor; + using FType = IRDocsifierFunctor; /*! \brief Create a IRDocsifier. */ explicit IRDocsifier(const PrinterConfig& cfg); /*! \brief The registration table for IRDocsifier. */ @@ -271,8 +273,7 @@ inline void FrameNode::ExitWithScope() { } template -inline static void AddDocDecoration(const Doc& d, const ffi::ObjectRef& obj, - const ffi::reflection::AccessPath& path, +inline static void AddDocDecoration(const Doc& d, const ffi::ObjectRef& obj, const AccessPath& path, const PrinterConfig& cfg) { if (cfg->obj_to_annotate.count(obj)) { if (const auto* stmt = d.as()) { @@ -292,7 +293,7 @@ inline static void AddDocDecoration(const Doc& d, const ffi::ObjectRef& obj, } } for (const auto& pair : cfg->path_to_annotate) { - ffi::reflection::AccessPath p = pair.first; + AccessPath p = pair.first; ffi::String attn = pair.second; if (p->IsPrefixOf(path) && path->IsPrefixOf(p)) { if (const auto* stmt = d.as()) { @@ -310,8 +311,7 @@ inline static void AddDocDecoration(const Doc& d, const ffi::ObjectRef& obj, } template -inline TDoc IRDocsifierNode::AsDoc(const Any& value, - const ffi::reflection::AccessPath& path) const { +inline TDoc IRDocsifierNode::AsDoc(const Any& value, const AccessPath& path) const { switch (value.type_index()) { case ffi::TypeIndex::kTVMFFINone: return Downcast(LiteralDoc::None(path)); diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index bf00144b0f66..5ff282adb68b 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -35,6 +35,15 @@ namespace tvm { +/*! + * \brief Abstract label for an area of memory. + * + * Currently uninterpreted and arbitrary. Likely to be replaced by a structured representation + * of a memory pool in the future. Please try to use this alias instead of ffi::String to aid future + * code migration. + */ +using MemoryScope = ffi::String; + // NOTE: cannot use enum as they are out of bound of the original enum // and results in an undefined behavior // A 'null' device type, does not correspond to any DLDeviceType enum. @@ -58,7 +67,7 @@ constexpr int kInvalidDeviceType = -1; * See "Virtual Devices" below. * - A \p target (\p Target) describing how to compile code for the intended device. May be null * if unconstrained. - * - A \p memory_scope (\p ffi::String) describing which memory + * - A \p memory_scope (\p MemoryScope, which is currently just \p String) describing which memory * area is to be used to hold data. May be "" if unconstrained. See "Memory Scopes and Devices" * below. * @@ -200,7 +209,7 @@ class VirtualDeviceNode : public AttrsNodeReflAdapter { * * Empty denotes unconstrained. */ - ffi::String memory_scope; + MemoryScope memory_scope; /*! * \brief Returns true if virtual device is 'fully unconstrained', ie no target/device type, @@ -270,7 +279,7 @@ class VirtualDevice : public ffi::ObjectRef { */ TVM_DLL explicit VirtualDevice(int device_type_int = kInvalidDeviceType, int virtual_device_id = -1, Target target = {}, - ffi::String memory_scope = {}); + MemoryScope memory_scope = {}); /*! \brief Returns the unique fully unconstrained \p VirtualDevice. */ static VirtualDevice FullyUnconstrained(); @@ -307,13 +316,13 @@ class VirtualDevice : public ffi::ObjectRef { } /*! \brief Returns the \p VirtualDevice for \p memory_scope alone. */ - static VirtualDevice ForMemoryScope(ffi::String memory_scope) { + static VirtualDevice ForMemoryScope(MemoryScope memory_scope) { return VirtualDevice(kInvalidDeviceType, -1, {}, std::move(memory_scope)); } /*! \brief Returns the \p VirtualDevice for \p device, \p target and \p memory_scope. */ TVM_DLL static VirtualDevice ForDeviceTargetAndMemoryScope(const Device& device, Target target, - ffi::String memory_scope) { + MemoryScope memory_scope) { return VirtualDevice(device.device_type, device.device_id, std::move(target), std::move(memory_scope)); } @@ -349,7 +358,7 @@ class TVM_DLL VirtualDeviceCache { public: /*! \brief Returns the unique \p VirtualDevice representing given fields. */ VirtualDevice Make(int device_type = kInvalidDeviceType, int virtual_device_id = -1, - Target target = {}, ffi::String memory_scope = {}); + Target target = {}, MemoryScope memory_scope = {}); /*! * \brief Returns the unique \p VirtualDevice structurally equal to the given \p virtual_device. diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h index 3a59ada864e6..c953f12e3870 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -42,7 +42,7 @@ namespace tvm { #define TVM_TIR_REGISTER_OP(OpName) \ - TVM_REGISTER_OP("tirx." OpName).set_attr("TScriptPrinterName", OpName) + TVM_REGISTER_OP("tirx." OpName).set_attr("TScriptPrinterName", OpName) // Most common operators can be overloaded by argument type(PrimExpr). // So we put them under the root namespace. diff --git a/include/tvm/tirx/op_attr_types.h b/include/tvm/tirx/op_attr_types.h index 2d9aef4b257d..9d0173bfd49f 100644 --- a/include/tvm/tirx/op_attr_types.h +++ b/include/tvm/tirx/op_attr_types.h @@ -36,6 +36,11 @@ namespace tvm { namespace tirx { +/*! + * \brief Global symbol of the op after lowering. + */ +using TGlobalSymbol = ffi::String; + /*! * \brief Whether the op is overloaded for vector form. */ @@ -51,6 +56,11 @@ using FLowerIntrinsic = ffi::TypedFunction; */ using FLegalize = ffi::TypedFunction; +/*! + * \brief The operator's name in TVMScript printer + */ +using TScriptPrinterName = ffi::String; + /*! * \brief Specifies that TVMScript printer prints the dtype as the first/last argument. If not specified, dtype will not be printed. diff --git a/include/tvm/tirx/var.h b/include/tvm/tirx/var.h index 3d84fb00bc0e..c38908d56d7d 100644 --- a/include/tvm/tirx/var.h +++ b/include/tvm/tirx/var.h @@ -173,7 +173,7 @@ class SizeVar : public Var { using ContainerType = SizeVarNode; }; -// NOTE: Region was an alias for ffi::Array; use ffi::Array directly. +using Region = ffi::Array; /*! * \brief Type of iteration variable. diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 12ca88a60ca3..977ea779f450 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -68,8 +68,8 @@ class BufferTouchedDomain final : public IRVisitorWithAnalyzer { return buffer_access_map_; } - ffi::Array FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) { - ffi::Array ret; + Region FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) { + Region ret; auto kv = buffer_access_map_.find(buffer.get()); if (kv == buffer_access_map_.end()) { LOG(WARNING) << "[arith::BufferDomainTouched] " @@ -133,8 +133,8 @@ class BufferTouchedDomain final : public IRVisitorWithAnalyzer { std::unordered_map buffer_access_map_; }; -ffi::Array DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, - bool consider_stores) { +Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, + bool consider_stores) { return BufferTouchedDomain(stmt).FindUnion(buffer, consider_loads, consider_stores); } diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index 4bb37ae1b062..d8bba04c5138 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -39,7 +39,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -VDevice::VDevice(Target tgt, int dev_id, ffi::String mem_scope) { +VDevice::VDevice(Target tgt, int dev_id, MemoryScope mem_scope) { ffi::ObjectPtr n = ffi::make_object(); n->target = std::move(tgt); n->vdevice_id = std::move(dev_id); @@ -49,7 +49,7 @@ VDevice::VDevice(Target tgt, int dev_id, ffi::String mem_scope) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.VDevice", [](Target tgt, int dev_id, ffi::String mem_scope) { + refl::GlobalDef().def("ir.VDevice", [](Target tgt, int dev_id, MemoryScope mem_scope) { return VDevice(tgt, dev_id, mem_scope); }); } diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 108f746f4eed..e9cb175fdcc7 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -405,7 +405,7 @@ class CodeGenVM : public ExprFunctor { } // Emits call to packed function `name` with arguments copied over from `call_node` args - void EmitPackedFuncCall(const Call& call_node, const ffi::String& name, RegName dst_reg) { + void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) { std::vector args = VisitArray(call_node->args); builder_->EmitCall(name, args, dst_reg); } diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 03770e3a1cfc..1d9e48ca6381 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -977,7 +977,7 @@ TVM_REGISTER_OP("relax.print") "The first value is Python-style format string to use to print. The others " "are values to print") .set_attr("FInferStructInfo", ReturnVoidStructInfo) - .set_attr("FCallPacked", "relax.run.print") + .set_attr("FCallPacked", "relax.run.print") .set_attr("FPurity", Bool(false)); Expr MakePrint(ffi::Array vals, StringImm format) { @@ -1023,7 +1023,7 @@ TVM_REGISTER_OP("relax.assert_op") "Python-style format string to use for displaying an error message, if the " "assert fails. The others are used as format arguments if there is an error.") .set_attr("FInferStructInfo", InferAssertStructInfo) - .set_attr("FCallPacked", "relax.run.assert_op") + .set_attr("FCallPacked", "relax.run.assert_op") .set_attr("FPurity", Bool(false)); Expr MakeAssertOp(Expr condition, ffi::Array vals, StringImm format) { @@ -1204,7 +1204,7 @@ TVM_REGISTER_OP("relax.shape_to_tensor") .set_num_inputs(1) .add_argument("input", "Expr", "The input expression") .set_attr("FInferStructInfo", ReturnShapeToTensorStructInfo) - .set_attr("FCallPacked", "relax.run.shape_to_tensor") + .set_attr("FCallPacked", "relax.run.shape_to_tensor") .set_attr("FPurity", Bool(true)); Expr MakeShapeToTensor(Expr expr) { diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index b1e23edb7340..183c254fb8fd 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -166,7 +166,7 @@ TVM_REGISTER_OP("relax.unique") "flattened input " "are returned.") .set_attr("FInferStructInfo", InferStructInfoUnique) - .set_attr("FCallPacked", "relax.run.unique") + .set_attr("FCallPacked", "relax.run.unique") .set_attr("FPurity", Bool(true)); /* relax.nonzero */ @@ -189,7 +189,7 @@ TVM_REGISTER_OP("relax.nonzero") .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoNonzero) - .set_attr("FCallPacked", "relax.run.nonzero") + .set_attr("FCallPacked", "relax.run.nonzero") .set_attr("FPurity", Bool(true)); } // namespace relax diff --git a/src/relax/script/printer/binding.cc b/src/relax/script/printer/binding.cc index da8e6ae8de01..d756a82a0e18 100644 --- a/src/relax/script/printer/binding.cc +++ b/src/relax/script/printer/binding.cc @@ -24,7 +24,7 @@ namespace tvm { namespace script { namespace printer { -IfDoc PrintIfExpr(const relax::If& n, const ffi::reflection::AccessPath& n_p, +IfDoc PrintIfExpr(const relax::If& n, const AccessPath& n_p, const IRDocsifier& d, // const ffi::Optional& var, const ffi::Optional& ann) { using relax::SeqExpr; @@ -44,7 +44,7 @@ IfDoc PrintIfExpr(const relax::If& n, const ffi::reflection::AccessPath& n_p, TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::MatchCast n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::MatchCast n, AccessPath n_p, IRDocsifier d) -> Doc { using relax::StructInfo; using relax::MatchStructInfo; ffi::Optional ann = std::nullopt; @@ -60,7 +60,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::VarBinding n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::VarBinding n, AccessPath n_p, IRDocsifier d) -> Doc { if (const auto if_ = n->value.as()) { ffi::Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); @@ -85,11 +85,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", - [](relax::If n, ffi::reflection::AccessPath n_p, - IRDocsifier d) -> Doc { - return PrintIfExpr(n, n_p, d, std::nullopt, std::nullopt); - }); + .set_dispatch("", [](relax::If n, AccessPath n_p, IRDocsifier d) -> Doc { + return PrintIfExpr(n, n_p, d, std::nullopt, std::nullopt); + }); TVM_REGISTER_SCRIPT_AS_REPR(relax::MatchCastNode, ReprPrintRelax); TVM_REGISTER_SCRIPT_AS_REPR(relax::VarBindingNode, ReprPrintRelax); diff --git a/src/relax/script/printer/call.cc b/src/relax/script/printer/call.cc index af4cb54f6848..262be66e924c 100644 --- a/src/relax/script/printer/call.cc +++ b/src/relax/script/printer/call.cc @@ -29,8 +29,8 @@ namespace printer { class AttrPrinter { public: - explicit AttrPrinter(ffi::reflection::AccessPath p, const IRDocsifier& d, - ffi::Array* keys, ffi::Array* values) + explicit AttrPrinter(AccessPath p, const IRDocsifier& d, ffi::Array* keys, + ffi::Array* values) : p(std::move(p)), d(d), keys(keys), values(values) {} void operator()(const tvm::Attrs& attrs) { @@ -54,14 +54,13 @@ class AttrPrinter { } } - ffi::reflection::AccessPath p; + AccessPath p; const IRDocsifier& d; ffi::Array* keys; ffi::Array* values; }; -ExprDoc PrintCallee(const relax::Expr& n, const ffi::reflection::AccessPath& n_p, - const IRDocsifier& d) { +ExprDoc PrintCallee(const relax::Expr& n, const AccessPath& n_p, const IRDocsifier& d) { // TODO(@junrushao): handle callee better if (const auto* ext = n.as()) { return LiteralDoc::Str(ext->global_symbol, n_p); @@ -70,8 +69,7 @@ ExprDoc PrintCallee(const relax::Expr& n, const ffi::reflection::AccessPath& n_p } } -ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, - const ffi::reflection::AccessPath& n_p, +ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); @@ -94,12 +92,12 @@ ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, args.push_back(d->AsDoc(n->args[1], n_p->Attr("args")->ArrayItem(1))); // Step 3. Print n->sinfo_args, the output struct info relax::StructInfo o_sinfo = n->sinfo_args[0]; - ffi::reflection::AccessPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayItem(0); + AccessPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayItem(0); bool is_dtensor = false; kwargs_keys.push_back("out_sinfo"); if (const auto* o = o_sinfo.as()) { ffi::Array fields; - ffi::reflection::AccessPath fields_p = o_sinfo_p->Attr("fields"); + AccessPath fields_p = o_sinfo_p->Attr("fields"); for (int i = 0, l = o->fields.size(); i < l; ++i) { if (o->fields[i].as()) { is_dtensor = true; @@ -162,7 +160,7 @@ ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, } } -ffi::Optional PrintAssertOp(const relax::Call& n, const ffi::reflection::AccessPath& n_p, +ffi::Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& assert_op = Op::Get("relax.assert_op"); if (!n->op.same_as(assert_op)) { @@ -182,8 +180,7 @@ ffi::Optional PrintAssertOp(const relax::Call& n, const ffi::reflection return Relax(d, "assert_op")->Call(args, {"format"}, {second_arg}); } -ffi::Optional PrintHintOnDevice(const relax::Call& n, - const ffi::reflection::AccessPath& n_p, +ffi::Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& hint_on_device_op = Op::Get("relax.hint_on_device"); if (!n->op.same_as(hint_on_device_op)) { @@ -206,7 +203,7 @@ ffi::Optional PrintHintOnDevice(const relax::Call& n, return Relax(d, "hint_on_device")->Call(args); } -ffi::Optional PrintToVDevice(const relax::Call& n, const ffi::reflection::AccessPath& n_p, +ffi::Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& to_vdevice_op = Op::Get("relax.to_vdevice"); if (!n->op.same_as(to_vdevice_op)) { @@ -230,7 +227,7 @@ ffi::Optional PrintToVDevice(const relax::Call& n, const ffi::reflectio return Relax(d, "to_vdevice")->Call(args, kwargs_keys, kwargs_values); } -ffi::Optional PrintRelaxPrint(const relax::Call& n, const ffi::reflection::AccessPath& n_p, +ffi::Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& print_op = Op::Get("relax.print"); if (!n->op.same_as(print_op)) { @@ -251,7 +248,7 @@ ffi::Optional PrintRelaxPrint(const relax::Call& n, const ffi::reflecti TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::Call n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::Call n, AccessPath n_p, IRDocsifier d) -> Doc { // Special case: call_tir, call_dps_packed, call_tir_with_grad if (ffi::Optional doc = PrintCallTIRDPSPacked(n, n_p, d)) { return doc.value(); @@ -325,7 +322,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 4. Print type_args if (n->sinfo_args.size() > 0) { - ffi::reflection::AccessPath sinfo_args_p = n_p->Attr("sinfo_args"); + AccessPath sinfo_args_p = n_p->Attr("sinfo_args"); ffi::Array sinfo_args; for (int i = 0, l = n->sinfo_args.size(); i < l; ++i) { sinfo_args.push_back(d->AsDoc(n->sinfo_args[i], sinfo_args_p->ArrayItem(i))); diff --git a/src/relax/script/printer/distributed.cc b/src/relax/script/printer/distributed.cc index 98a96b84ebbe..0a67b55af89f 100644 --- a/src/relax/script/printer/distributed.cc +++ b/src/relax/script/printer/distributed.cc @@ -30,17 +30,14 @@ namespace printer { // distributed::Placement TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", - [](relax::distributed::Placement n, - ffi::reflection::AccessPath n_p, + [](relax::distributed::Placement n, AccessPath n_p, IRDocsifier d) -> Doc { return d->AsDoc(n->ToString(), n_p); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", - [](relax::distributed::DTensorStructInfo n, ffi::reflection::AccessPath n_p, - IRDocsifier d) -> Doc { + "", [](relax::distributed::DTensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -49,7 +46,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Need to dig into ShapeExpr to preserve the `R.shape` prefix if (const auto* shape = n->tensor_sinfo->shape.value().as()) { auto shape_expr = ffi::GetRef(shape); - ffi::reflection::AccessPath shape_p = n_p->Attr("shape")->Attr("values"); + AccessPath shape_p = n_p->Attr("shape")->Attr("values"); ffi::Array shape_docs; for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { shape_docs.push_back( @@ -94,9 +91,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", - [](relax::distributed::DeviceMesh n, ffi::reflection::AccessPath n_p, - IRDocsifier d) -> Doc { + "", [](relax::distributed::DeviceMesh n, AccessPath n_p, IRDocsifier d) -> Doc { bool has_relax_frame = false; const IRFrameNode* f = nullptr; for (const Frame& frame : d->frames) { diff --git a/src/relax/script/printer/expr.cc b/src/relax/script/printer/expr.cc index b6d750bc2df3..c8a813b8d5ab 100644 --- a/src/relax/script/printer/expr.cc +++ b/src/relax/script/printer/expr.cc @@ -30,32 +30,32 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::PrimValue n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::PrimValue n, AccessPath n_p, IRDocsifier d) -> Doc { // TODO(@junrushao): float numbers return Relax(d, "prim_value")->Call({d->AsDoc(n->value, n_p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::StringImm n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::StringImm n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "str")->Call({LiteralDoc::Str(n->value, n_p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::DataTypeImm n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::DataTypeImm n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "dtype")->Call({LiteralDoc::DataType(n->value, n_p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::Tuple n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::Tuple n, AccessPath n_p, IRDocsifier d) -> Doc { // TODO(@junrushao): revisit tuple printing if (n->fields.empty()) { return Relax(d, "tuple")->Call({}); } ffi::Array fields_doc; - ffi::reflection::AccessPath fields_p = n_p->Attr("fields"); + AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); } @@ -64,24 +64,23 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TupleGetItem n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TupleGetItem n, AccessPath n_p, IRDocsifier d) -> Doc { ExprDoc idx = LiteralDoc::Int(n->index, n_p->Attr("index")); return d->AsDoc(n->tuple, n_p->Attr("tuple"))[{idx}]; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ShapeExpr n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ShapeExpr n, AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array values_doc; - ffi::reflection::AccessPath values_p = n_p->Attr("values"); + AccessPath values_p = n_p->Attr("values"); for (int i = 0, l = n->values.size(); i < l; ++i) { values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayItem(i), d)); } return Relax(d, "shape")->Call({ListDoc(values_doc)}); }); -ffi::Optional SpecialScalar(const runtime::Tensor& n, - const ffi::reflection::AccessPath& p) { +ffi::Optional SpecialScalar(const runtime::Tensor& n, const AccessPath& p) { DataType dtype = n.DataType(); const void* data = n->data; if (n->ndim != 0 || n->device.device_type != kDLCPU) { @@ -136,7 +135,7 @@ ffi::Optional SpecialScalar(const runtime::Tensor& n, TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::Constant n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::Constant n, AccessPath n_p, IRDocsifier d) -> Doc { if (ffi::Optional s = SpecialScalar(n->data, n_p->Attr("data"))) { if (n->struct_info_.as()) { ExprDoc ann = d->AsDoc(n->struct_info_, n_p->Attr("struct_info_")); @@ -151,7 +150,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return d->AddMetadata(n); }); -Doc PrintRelaxVar(relax::Var n, ffi::reflection::AccessPath p, IRDocsifier d) { +Doc PrintRelaxVar(relax::Var n, AccessPath p, IRDocsifier d) { if (!d->IsVarDefined(n)) { ExprDoc ann = d->AsDoc(n->struct_info_, p->Attr("struct_info_")); Frame f = d->frames.back(); diff --git a/src/relax/script/printer/function.cc b/src/relax/script/printer/function.cc index bc4309a7a0d7..e30a2b0bf432 100644 --- a/src/relax/script/printer/function.cc +++ b/src/relax/script/printer/function.cc @@ -49,99 +49,94 @@ bool AtTopLevelFunction(const IRDocsifier& d) { TVM_FFI_STATIC_INIT_BLOCK() { RelaxFrameNode::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](relax::Function n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { - std::unordered_set func_vars; - With f(d); + .set_dispatch("", [](relax::Function n, AccessPath n_p, IRDocsifier d) -> Doc { + std::unordered_set func_vars; + With f(d); - IdDoc func_name(""); - // if we are binding a local definition, then calling d->Define - // will result in a repeated definition and an incorrect displayed name - if (ffi::Optional name = GetBindingName(d)) { - func_name = IdDoc(name.value()); - } else { - func_name = IdDoc(FindFunctionName(d, n).value_or("main")); - } - (*f)->AddDispatchToken(d, "relax"); - (*f)->is_func = true; - (*f)->func_vars = &func_vars; - // Step 1. Print the return type - ffi::Optional ret_type = std::nullopt; - if (const auto& func_sinfo = relax::MatchStructInfo(n)) { - ret_type = d->AsDoc(func_sinfo.value()->ret, // - n_p->Attr("struct_info_")->Attr("ret")); - } - // Step 2. Print params - ffi::Array params; - { - ffi::reflection::AccessPath params_p = n_p->Attr("params"); - for (int i = 0, l = n->params.size(); i < l; ++i) { - params.push_back(AssignDoc( - /*lhs=*/DefineVar(n->params[i], *f, d), - /*rhs=*/std::nullopt, - StructInfoAsAnn(n->params[i], params_p->ArrayItem(i), d, std::nullopt))); - } - } - // Step 3. Clean up func variables - (*f)->func_vars = nullptr; - // Step 4. Print attributes - if (n->attrs.defined() && !n->attrs->dict.empty()) { - // If the function is a global function and has a global symbol, - // then don't print the global symbol (it will be implicit from not being private). - // For a function without an IR module whose global symbol - // doesn't match the function name, we should still print the global symbol attribute. - if (AtTopLevelFunction(d) && n->attrs->dict.count(tvm::attr::kGlobalSymbol) && - Downcast(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == - func_name->name) { - ffi::Map new_attrs; - for (auto kv : n->attrs->dict) { - if (kv.first != tvm::attr::kGlobalSymbol) { - new_attrs.Set(kv.first, kv.second); - } - } - if (!new_attrs.empty()) { - (*f)->stmts.push_back(ExprStmtDoc( - Relax(d, "func_attr") // - ->Call({d->AsDoc(DictAttrs(new_attrs), n_p->Attr("attrs"))}))); - } - } else { - (*f)->stmts.push_back( - ExprStmtDoc(Relax(d, "func_attr") // - ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); + IdDoc func_name(""); + // if we are binding a local definition, then calling d->Define + // will result in a repeated definition and an incorrect displayed name + if (ffi::Optional name = GetBindingName(d)) { + func_name = IdDoc(name.value()); + } else { + func_name = IdDoc(FindFunctionName(d, n).value_or("main")); + } + (*f)->AddDispatchToken(d, "relax"); + (*f)->is_func = true; + (*f)->func_vars = &func_vars; + // Step 1. Print the return type + ffi::Optional ret_type = std::nullopt; + if (const auto& func_sinfo = relax::MatchStructInfo(n)) { + ret_type = d->AsDoc(func_sinfo.value()->ret, // + n_p->Attr("struct_info_")->Attr("ret")); + } + // Step 2. Print params + ffi::Array params; + { + AccessPath params_p = n_p->Attr("params"); + for (int i = 0, l = n->params.size(); i < l; ++i) { + params.push_back(AssignDoc( + /*lhs=*/DefineVar(n->params[i], *f, d), + /*rhs=*/std::nullopt, + StructInfoAsAnn(n->params[i], params_p->ArrayItem(i), d, std::nullopt))); + } + } + // Step 3. Clean up func variables + (*f)->func_vars = nullptr; + // Step 4. Print attributes + if (n->attrs.defined() && !n->attrs->dict.empty()) { + // If the function is a global function and has a global symbol, + // then don't print the global symbol (it will be implicit from not being private). + // For a function without an IR module whose global symbol + // doesn't match the function name, we should still print the global symbol attribute. + if (AtTopLevelFunction(d) && n->attrs->dict.count(tvm::attr::kGlobalSymbol) && + Downcast(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { + ffi::Map new_attrs; + for (auto kv : n->attrs->dict) { + if (kv.first != tvm::attr::kGlobalSymbol) { + new_attrs.Set(kv.first, kv.second); } } - // Step 5. Prepare the decorator (include purity if it's impure) - ExprDoc decorator = Relax(d, "function"); - ffi::Array pos_args = {}; - ffi::Array dec_keys; - ffi::Array dec_values; - if (!n->is_pure) { - dec_keys.push_back("pure"); - dec_values.push_back( - LiteralDoc::Boolean(false, ffi::Optional())); - } - // if the function is global or is not in a module and does not have a global symbol, - // indicate that it's private - if (AtTopLevelFunction(d) && - (!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) { - dec_keys.push_back("private"); - dec_values.push_back( - LiteralDoc::Boolean(true, ffi::Optional())); - } - if (dec_keys.size()) { - decorator = decorator->Call(pos_args, dec_keys, dec_values); + if (!new_attrs.empty()) { + (*f)->stmts.push_back(ExprStmtDoc( + Relax(d, "func_attr") // + ->Call({d->AsDoc(DictAttrs(new_attrs), n_p->Attr("attrs"))}))); } + } else { + (*f)->stmts.push_back( + ExprStmtDoc(Relax(d, "func_attr") // + ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); + } + } + // Step 5. Prepare the decorator (include purity if it's impure) + ExprDoc decorator = Relax(d, "function"); + ffi::Array pos_args = {}; + ffi::Array dec_keys; + ffi::Array dec_values; + if (!n->is_pure) { + dec_keys.push_back("pure"); + dec_values.push_back(LiteralDoc::Boolean(false, ffi::Optional())); + } + // if the function is global or is not in a module and does not have a global symbol, + // indicate that it's private + if (AtTopLevelFunction(d) && + (!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) { + dec_keys.push_back("private"); + dec_values.push_back(LiteralDoc::Boolean(true, ffi::Optional())); + } + if (dec_keys.size()) { + decorator = decorator->Call(pos_args, dec_keys, dec_values); + } - // Step 6. Print body - ffi::Array body = PrintSeqExpr(n->body, n_p->Attr("body"), d, /*use_ret=*/true); - (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); - return HeaderWrapper(d, - FunctionDoc(func_name, params, {decorator}, ret_type, (*f)->stmts)); - }); + // Step 6. Print body + ffi::Array body = PrintSeqExpr(n->body, n_p->Attr("body"), d, /*use_ret=*/true); + (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); + return HeaderWrapper(d, FunctionDoc(func_name, params, {decorator}, ret_type, (*f)->stmts)); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ExternFunc n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ExternFunc n, AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array args; args.push_back(LiteralDoc::Str(n->global_symbol, n_p->Attr("global_symbol"))); if (!HasDefaultExternFuncStructInfo(n)) { diff --git a/src/relax/script/printer/region.cc b/src/relax/script/printer/region.cc index 561877db8032..f5b50e66c97f 100644 --- a/src/relax/script/printer/region.cc +++ b/src/relax/script/printer/region.cc @@ -24,11 +24,11 @@ namespace tvm { namespace script { namespace printer { -ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const ffi::reflection::AccessPath& n_p, +ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, const IRDocsifier& d, bool use_ret) { With f(d); const ffi::Array& blocks = n->blocks; - ffi::reflection::AccessPath blocks_p = n_p->Attr("blocks"); + AccessPath blocks_p = n_p->Attr("blocks"); ffi::Array* stmts = &(*f)->stmts; for (int i = 0, l = blocks.size(); i < l; ++i) { Doc block = d->AsDoc(blocks[i], blocks_p->ArrayItem(i)); @@ -50,21 +50,19 @@ ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const ffi::reflection: } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", - [](relax::SeqExpr n, ffi::reflection::AccessPath n_p, - IRDocsifier d) -> Doc { - return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false)); - }); + .set_dispatch("", [](relax::SeqExpr n, AccessPath n_p, IRDocsifier d) -> Doc { + return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false)); + }); -ffi::Array PrintBindingBlock(const relax::BindingBlock& n, - const ffi::reflection::AccessPath& n_p, const IRDocsifier& d, +ffi::Array PrintBindingBlock(const relax::BindingBlock& n, const AccessPath& n_p, + const IRDocsifier& d, ffi::Array* non_dataflow_vars) { const ffi::Array& bindings = n->bindings; - ffi::reflection::AccessPath bindings_p = n_p->Attr("bindings"); + AccessPath bindings_p = n_p->Attr("bindings"); ffi::Array stmts; for (int i = 0, l = bindings.size(); i < l; ++i) { const relax::Binding& binding = bindings[i]; - ffi::reflection::AccessPath binding_p = bindings_p->ArrayItem(i); + AccessPath binding_p = bindings_p->ArrayItem(i); TVM_FFI_ICHECK(binding->var.defined()); Doc binding_doc = d->AsDoc(binding, binding_p); if (const auto* stmt = binding_doc.as()) { @@ -83,13 +81,13 @@ ffi::Array PrintBindingBlock(const relax::BindingBlock& n, TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::BindingBlock n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::BindingBlock n, AccessPath n_p, IRDocsifier d) -> Doc { return StmtBlockDoc(PrintBindingBlock(n, n_p, d, nullptr)); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::DataflowBlock n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::DataflowBlock n, AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array non_dataflow_vars; ffi::Array stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars); stmts.push_back(ExprStmtDoc(Relax(d, "output")->Call(non_dataflow_vars))); diff --git a/src/relax/script/printer/struct_info.cc b/src/relax/script/printer/struct_info.cc index f4f054b24c6a..1019cfa7e9bb 100644 --- a/src/relax/script/printer/struct_info.cc +++ b/src/relax/script/printer/struct_info.cc @@ -27,12 +27,11 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ObjectStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ObjectStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Object"); }); -ExprDoc PrintShapeVar(const PrimExpr& e, const ffi::reflection::AccessPath& e_p, - const IRDocsifier& d) { +ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifier& d) { ExprDoc expr_doc = d->AsDoc(e, e_p); // Step 1. Find if `func_vars` are being collected const RelaxFrameNode* f = nullptr; @@ -64,7 +63,7 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const ffi::reflection::AccessPath& e_p, TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::PrimStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::PrimStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -81,10 +80,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::ShapeStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ShapeStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { if (n->values.defined()) { ffi::Array shape = n->values.value(); - ffi::reflection::AccessPath shape_p = n_p->Attr("values"); + AccessPath shape_p = n_p->Attr("values"); ffi::Array shape_docs; for (int i = 0, ndim = shape.size(); i < ndim; ++i) { shape_docs.push_back(PrintShapeVar(shape[i], shape_p->ArrayItem(i), d)); @@ -97,7 +96,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TensorStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -105,7 +104,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Need to dig into ShapeExpr to preserve the `R.shape` prefix if (const auto* shape = n->shape.value().as()) { auto shape_expr = ffi::GetRef(shape); - ffi::reflection::AccessPath shape_p = n_p->Attr("shape")->Attr("values"); + AccessPath shape_p = n_p->Attr("shape")->Attr("values"); ffi::Array shape_docs; for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { shape_docs.push_back( @@ -140,12 +139,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TupleStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TupleStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { if (n->fields.empty()) { return Relax(d, "Tuple"); } ffi::Array fields_doc; - ffi::reflection::AccessPath fields_p = n_p->Attr("fields"); + AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); } @@ -154,7 +153,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::FuncStructInfo n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::FuncStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { auto ret_doc = d->AsDoc(n->ret, n_p->Attr("ret")); auto purity_doc = LiteralDoc::Boolean(n->purity, n_p->Attr("purity")); @@ -180,7 +179,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // TODO(@junrushao): track symbolic shape relation ffi::Array params_doc; ffi::Array params = n->params.value(); - ffi::reflection::AccessPath params_p = n_p->Attr("params"); + AccessPath params_p = n_p->Attr("params"); for (int i = 0, n_params = params.size(); i < n_params; ++i) { params_doc.push_back(d->AsDoc(params[i], params_p->ArrayItem(i))); } diff --git a/src/relax/script/printer/tir.cc b/src/relax/script/printer/tir.cc index 345ab9b1dcbf..e0742f8edd44 100644 --- a/src/relax/script/printer/tir.cc +++ b/src/relax/script/printer/tir.cc @@ -42,7 +42,7 @@ RelaxFrameNode* GetRelaxFrame(IRDocsifier d) { return f; } -Doc PrintTIRVar(tirx::Var n, ffi::reflection::AccessPath n_p, IRDocsifier d) { +Doc PrintTIRVar(tirx::Var n, AccessPath n_p, IRDocsifier d) { TVM_FFI_CHECK(n->dtype.is_scalar(), TypeError) << "Relax only uses scalar TIR variables," << "but received TIR variable " << n << " with dtype " << n->dtype; @@ -74,8 +74,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", Prin TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "relax", [](tvm::IntImm n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // + .set_dispatch( // + "relax", [](tvm::IntImm n, AccessPath n_p, IRDocsifier d) -> Doc { // // TODO(@junrushao): support non-int64 cases if (n->dtype.is_bool()) { return LiteralDoc::Boolean(n->value, n_p); @@ -85,8 +85,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "relax", [](tvm::GlobalVar n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // + .set_dispatch( // + "relax", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } else { @@ -97,8 +97,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "relax", [](tvm::IRModule mod, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // + .set_dispatch( // + "relax", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // ffi::Optional doc = d->GetVarDoc(mod); TVM_FFI_ICHECK(doc) << "Unable to print IRModule before definition in Relax."; if (d->cfg->module_alias.empty()) { @@ -118,14 +118,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("relax", - [](Range range, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { - return Relax(d, "Range") - ->Call({ - d->AsDoc(range->min, p->Attr("min")), - d->AsDoc(range->extent + range->min, p->Attr("extent")), - }); - }); + .set_dispatch("relax", [](Range range, AccessPath p, IRDocsifier d) -> Doc { + return Relax(d, "Range") + ->Call({ + d->AsDoc(range->min, p->Attr("min")), + d->AsDoc(range->extent + range->min, p->Attr("extent")), + }); + }); } // namespace printer } // namespace script diff --git a/src/relax/script/printer/type.cc b/src/relax/script/printer/type.cc index 1c01972c517d..f5cbfcb16615 100644 --- a/src/relax/script/printer/type.cc +++ b/src/relax/script/printer/type.cc @@ -26,20 +26,20 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ShapeType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ShapeType n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Shape") ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ObjectType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ObjectType n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Object"); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TensorType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TensorType n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Tensor") ->Call({}, {"ndim", "dtype"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim")), @@ -48,18 +48,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::PackedFuncType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "", [](relax::PackedFuncType n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "PackedFunc"); // TODO(@junrushao): verify if this is correct }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "relax", [](tvm::TupleType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "relax", [](tvm::TupleType n, AccessPath n_p, IRDocsifier d) -> Doc { if (n->fields.empty()) { return Relax(d, "Tuple"); } ffi::Array fields_doc; - ffi::reflection::AccessPath fields_p = n_p->Attr("fields"); + AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); } @@ -68,10 +68,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "relax", [](tvm::FuncType n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + "relax", [](tvm::FuncType n, AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array arg_types_doc; ffi::Array arg_types = n->arg_types; - ffi::reflection::AccessPath arg_types_p = n_p->Attr("arg_types"); + AccessPath arg_types_p = n_p->Attr("arg_types"); for (int i = 0, n_params = arg_types.size(); i < n_params; ++i) { arg_types_doc.push_back(d->AsDoc(arg_types[i], arg_types_p->ArrayItem(i))); } diff --git a/src/relax/script/printer/utils.h b/src/relax/script/printer/utils.h index aabb73abd2cd..607728cb5b69 100644 --- a/src/relax/script/printer/utils.h +++ b/src/relax/script/printer/utils.h @@ -79,8 +79,7 @@ inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsif return d->Define(var, frame, var->name_hint().empty() ? "v" : var->name_hint()); } -inline ffi::Optional StructInfoAsAnn(const relax::Var& v, - const ffi::reflection::AccessPath& v_p, +inline ffi::Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& v_p, const IRDocsifier& d, const ffi::Optional& rhs) { if (!v->struct_info_.defined()) { @@ -134,11 +133,10 @@ inline ffi::Optional StructInfoAsAnn(const relax::Var& v, return d->AsDoc(v->struct_info_, v_p->Attr("struct_info_")); } -ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const ffi::reflection::AccessPath& n_p, +ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, const IRDocsifier& d, bool use_ret); -ExprDoc PrintShapeVar(const PrimExpr& e, const ffi::reflection::AccessPath& e_p, - const IRDocsifier& d); +ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifier& d); inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifier& d) { ffi::Array vdevices = d->global_infos["vdevice"]; diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 6b618608162d..a6d74d91721d 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -236,7 +236,7 @@ class LegalizeMutator : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { Call visited_call = Downcast(this->VisitExprPostOrder_(call)); static const auto& legalize_map = Op::GetAttrMap("FLegalize"); - static const auto& call_packed_map = Op::GetAttrMap("FCallPacked"); + static const auto& call_packed_map = Op::GetAttrMap("FCallPacked"); static const auto& requires_arg_shapes_map = Op::GetAttrMap("RequiresArgumentShapes"); static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); static const Op& call_tir_op = Op::Get("relax.call_tir"); diff --git a/src/s_tir/schedule/primitive/cache_index.cc b/src/s_tir/schedule/primitive/cache_index.cc index e162117b67b0..9566817f8015 100644 --- a/src/s_tir/schedule/primitive/cache_index.cc +++ b/src/s_tir/schedule/primitive/cache_index.cc @@ -289,7 +289,7 @@ ffi::Array MakeIndexCacheStage(IndexInfo* info, const ffi::String& stora // block variables ffi::Array block_vars; // block access region for write buffers - ffi::Array access_region; + Region access_region; // indices used in block body ffi::Array access_indices; ffi::Map block_var_map; diff --git a/src/s_tir/schedule/primitive/cache_read_write.cc b/src/s_tir/schedule/primitive/cache_read_write.cc index d5015b75a318..3c754d1fa3af 100644 --- a/src/s_tir/schedule/primitive/cache_read_write.cc +++ b/src/s_tir/schedule/primitive/cache_read_write.cc @@ -182,17 +182,17 @@ SBlock MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStage } // block access region for read/write buffers - ffi::Array read_access_region, write_access_region; + Region read_access_region, write_access_region; ffi::Array read_access_indices, write_access_indices; // Compute read/write region and read/write access indices. ffi::Array& old_indices = (is_cache_read) ? read_access_indices : write_access_indices; - ffi::Array& old_region = (is_cache_read) ? read_access_region : write_access_region; + Region& old_region = (is_cache_read) ? read_access_region : write_access_region; for (const Range& range : cache_region->region) { old_indices.push_back(Substitute(range->min, var_map)); old_region.push_back(Range::FromMinExtent(old_indices.back(), Integer(1))); } ffi::Array& new_indices = (is_cache_read) ? write_access_indices : read_access_indices; - ffi::Array& new_region = (is_cache_read) ? write_access_region : read_access_region; + Region& new_region = (is_cache_read) ? write_access_region : read_access_region; for (const PrimExpr& idx : info->indices) { new_indices.push_back(Substitute((idx), var_map)); new_region.push_back(Range::FromMinExtent(new_indices.back(), Integer(1))); @@ -254,8 +254,8 @@ SBlock MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, // block variables ffi::Array block_vars; // block access region for read/write buffers - ffi::Array read_access_region; - ffi::Array write_access_region; + Region read_access_region; + Region write_access_region; // indices used in block body ffi::Array read_access_indices; ffi::Array write_access_indices; @@ -384,8 +384,8 @@ SBlock MakeReIndexStage(const SBlock& block, CacheStageInfo* info, // Step 3: Create the reindex block // The src and the dst region and indices of the data copy - ffi::Array src_region{nullptr}; - ffi::Array dst_region{nullptr}; + Region src_region{nullptr}; + Region dst_region{nullptr}; ffi::Array src_indices{nullptr}; ffi::Array dst_indices{nullptr}; @@ -635,7 +635,7 @@ BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_re /*analyzer=*/&analyzer); TVM_FFI_ICHECK_EQ(buffer_region->region.size(), int_sets.size()); - ffi::Array region; + Region region; region.reserve(int_sets.size()); for (size_t i = 0; i < int_sets.size(); ++i) { region.push_back(int_sets[i].CoverRange(Range::FromMinExtent(0, buffer->shape[i]))); @@ -901,7 +901,7 @@ class CacheReadRewriter : public StmtExprMutator { explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info, bool cache_full_region = true) : scope_sref_(scope_sref), info_(info), cache_full_region_(cache_full_region) { - auto update_region = [this](const ffi::Array& region, const ffi::Array& offset) -> ffi::Array { + auto update_region = [this](const Region& region, const Region& offset) -> Region { TVM_FFI_ICHECK_EQ(region.size(), offset.size()); std::vector ret; for (size_t i = 0; i < region.size(); ++i) { @@ -1087,7 +1087,7 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { ffi::Array new_reads; for (const BufferRegion& buf_region : reads) { if (buf_region->buffer.same_as(info_->read_buffer)) { - ffi::Array region; + Region region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1103,7 +1103,7 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { for (const MatchBufferRegion& match_buffer_region : match_buffers) { BufferRegion source = match_buffer_region->source; if (source->buffer.same_as(info_->read_buffer)) { - ffi::Array region; + Region region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1158,7 +1158,7 @@ class CacheWriteRewriter : public StmtExprMutator { writer_block_sref_(writer_block_sref), info_(info), cache_full_region_(cache_full_region) { - auto update_region = [this](const ffi::Array& region, const ffi::Array& offset) -> ffi::Array { + auto update_region = [this](const Region& region, const Region& offset) -> Region { TVM_FFI_ICHECK_EQ(region.size(), offset.size()); std::vector ret; for (size_t i = 0; i < region.size(); ++i) { @@ -1376,7 +1376,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { ffi::Array new_reads; for (const BufferRegion& buf_region : reads) { if (buf_region->buffer.same_as(info_->write_buffer)) { - ffi::Array region; + Region region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1392,7 +1392,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { for (const MatchBufferRegion& match_buffer_region : match_buffers) { BufferRegion source = match_buffer_region->source; if (source->buffer.same_as(info_->write_buffer)) { - ffi::Array region; + Region region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1680,7 +1680,7 @@ class ReIndexRewriter : public StmtExprMutator { /*! \brief The new indices */ ffi::Array indices_; /*! \brief The new region */ - ffi::Array region_; + Region region_; }; void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root, Buffer read_buffer) { diff --git a/src/s_tir/schedule/primitive/decompose_padding.cc b/src/s_tir/schedule/primitive/decompose_padding.cc index c7a6ce1ceeb2..ee2045b7eef6 100644 --- a/src/s_tir/schedule/primitive/decompose_padding.cc +++ b/src/s_tir/schedule/primitive/decompose_padding.cc @@ -313,7 +313,7 @@ static std::pair CreateInBoundBlock(const SBlockRealizeNode auto rewrite_expr = [&repl_dict, analyzer](const PrimExpr& e) { return analyzer->Simplify(Substitute(e, repl_dict)); }; - auto rewrite_region = [rewrite_expr](const ffi::Array& region) { + auto rewrite_region = [rewrite_expr](const Region& region) { return region.Map([rewrite_expr](const Range& r) { return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent)); }); diff --git a/src/s_tir/schedule/primitive/rolling_buffer.cc b/src/s_tir/schedule/primitive/rolling_buffer.cc index 5c2b1a985da3..85e4d3b2a8bb 100644 --- a/src/s_tir/schedule/primitive/rolling_buffer.cc +++ b/src/s_tir/schedule/primitive/rolling_buffer.cc @@ -44,7 +44,7 @@ BufferRegion GetRelaxedBufferRegion(const SBlockRealize& realize, const BufferRe const ffi::Map& dom_map) { ffi::Array relaxed_intsets = arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); - ffi::Array relaxed_region; + Region relaxed_region; relaxed_region.reserve(relaxed_intsets.size()); for (size_t i = 0; i < relaxed_intsets.size(); ++i) { relaxed_region.push_back( @@ -165,7 +165,7 @@ class RollingBufferInfoCollector { private: bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { const Buffer& buffer = buffer_region->buffer; - const ffi::Array& region = buffer_region->region; + const Region& region = buffer_region->region; std::vector> bound_iter_vars; std::vector bound_overlaps; diff --git a/src/s_tir/support/nd_int_set.h b/src/s_tir/support/nd_int_set.h index df9aa8e3dc64..03f3672b452d 100644 --- a/src/s_tir/support/nd_int_set.h +++ b/src/s_tir/support/nd_int_set.h @@ -36,7 +36,7 @@ using NDIntSet = std::vector; * \param region The region. * \return The constructed set. */ -inline NDIntSet NDIntSetFromRegion(const ffi::Array& region) { +inline NDIntSet NDIntSetFromRegion(const tirx::Region& region) { NDIntSet result; result.reserve(region.size()); for (const Range& range : region) { diff --git a/src/s_tir/transform/compact_buffer_region.cc b/src/s_tir/transform/compact_buffer_region.cc index c4e68d24bd89..d01f24c670a4 100644 --- a/src/s_tir/transform/compact_buffer_region.cc +++ b/src/s_tir/transform/compact_buffer_region.cc @@ -47,7 +47,7 @@ using namespace tvm::tirx; using support::NDIntSet; /*! \brief a more constrained bound estimate for n-dimentional int set */ -NDIntSet NDIntSetEval(ffi::Array region, PrimExpr predicate, +NDIntSet NDIntSetEval(Region region, PrimExpr predicate, const std::unordered_map& dom_map, arith::Analyzer* analyzer) { std::unordered_map var_dom; @@ -111,7 +111,7 @@ class Var2BufferCollector : public StmtExprVisitor { */ class BufferAccessRegionCollector : public StmtExprVisitor { public: - static std::unordered_map, ffi::ObjectPtrHash, ffi::ObjectPtrEqual> Collect( + static std::unordered_map Collect( const PrimFunc& f, bool collect_inbound) { BufferAccessRegionCollector region_collector(collect_inbound); @@ -528,7 +528,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { * The entire access region should get updated on the buffer's define point * and we sanity check that every buffer is defined only once. */ - std::unordered_map, ffi::ObjectPtrHash, ffi::ObjectPtrEqual> buffer_access_region_; + std::unordered_map buffer_access_region_; /*! \brief The map from Buffer to it's access regions annotated by current block. */ std::unordered_map, ffi::ObjectPtrHash, ffi::ObjectPtrEqual> @@ -548,7 +548,7 @@ struct DimAlignInfo { struct BufferAllocInfo { /*! \brief The buffer access region. */ - ffi::Array region; + Region region; /*! \brief The storage alignment information. */ std::vector dim_aligns; /*! @@ -644,7 +644,7 @@ class BufferCompactor : public StmtExprMutator { *indices = std::move(new_indices); } - void RewriteBufferRegion(Buffer* buffer, ffi::Array* region) const { + void RewriteBufferRegion(Buffer* buffer, Region* region) const { auto it = buffer_info_.find((*buffer)->data); if (it == buffer_info_.end()) { // Skip if the buffer is parameter @@ -652,7 +652,7 @@ class BufferCompactor : public StmtExprMutator { } const BufferAllocInfo& info = it->second; TVM_FFI_ICHECK_EQ(region->size(), info.region.size()); - ffi::Array new_region; + Region new_region; new_region.reserve(info.region.size()); for (size_t i = 0; i < info.region.size(); ++i) { const Range& range = (*region)[i]; @@ -716,14 +716,14 @@ ffi::Array CalcStrides(const BufferAllocInfo& alloc_info, Stmt BufferCompactorCompact( const PrimFunc& f, - const std::unordered_map, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>& regions, + const std::unordered_map& regions, const std::unordered_map& storage_align) { // collect buffer allocation info for no-alias buffers std::unordered_map buffer_info; for (const auto& kv : regions) { const Buffer& buffer = kv.first; // set dim alignment info - ffi::Array region = kv.second; + Region region = kv.second; BufferAllocInfo alloc_info; auto it = storage_align.find(buffer->data); if (it != storage_align.end()) { diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index 14997709b8b5..151264405207 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -237,7 +237,7 @@ class PipelineBodyRewriter : public StmtExprMutator { BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const { auto it = buffer_remap_.find(buffer_region->buffer); if (it != buffer_remap_.end()) { - ffi::Array new_region = buffer_region->region; + Region new_region = buffer_region->region; const Buffer& new_buffer = (*it).second; // For pipeline buffers, relax the access region of the first dimension to full extent // if access_all_versions == true @@ -444,7 +444,7 @@ class PipelineRewriter : public StmtExprMutator { * \param region2 The second region. * \return Whether region1 and region2 have intersections. */ - bool MayConflict(ffi::Array region1, ffi::Array region2) { + bool MayConflict(Region region1, Region region2) { TVM_FFI_ICHECK(region1.size() == region2.size()); for (size_t i = 0; i < region1.size(); i++) { Range dim1 = region1[i]; @@ -1203,7 +1203,7 @@ class PipelineInjector : private StmtExprMutator { void AddAllocBuffers(SBlockNode* n, const ffi::Array alloc_buffers) { for (const Buffer& alloc_buffer : alloc_buffers) { n->alloc_buffers.push_back(alloc_buffer); - ffi::Array region; + Region region; region.reserve(alloc_buffer->shape.size()); for (const PrimExpr& dim : alloc_buffer->shape) { region.push_back(Range::FromMinExtent(0, dim)); diff --git a/src/s_tir/transform/lower_match_buffer.cc b/src/s_tir/transform/lower_match_buffer.cc index ac23bb87d537..4caa02bc713c 100644 --- a/src/s_tir/transform/lower_match_buffer.cc +++ b/src/s_tir/transform/lower_match_buffer.cc @@ -25,11 +25,11 @@ #include #include #include +#include #include #include #include #include -#include #include "../../tirx/ir/functor_common.h" #include "../../tirx/transform/ir_utils.h" @@ -154,7 +154,7 @@ class MatchBufferLower : public StmtExprMutator { return buffer_region; } else { const BufferRegion& source = (*it).second; - ffi::Array region = ConvertRegion(MatchBufferRegion(buffer, source), buffer_region->region); + Region region = ConvertRegion(MatchBufferRegion(buffer, source), buffer_region->region); return BufferRegion(source->buffer, std::move(region)); } } diff --git a/src/s_tir/transform/memhammer_lower_auto_copy.cc b/src/s_tir/transform/memhammer_lower_auto_copy.cc index b3536be6619a..3836256449f9 100644 --- a/src/s_tir/transform/memhammer_lower_auto_copy.cc +++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc @@ -573,7 +573,7 @@ class AutoPadder { Buffer src_buffer = r->source->buffer; runtime::StorageScope scope = runtime::StorageScope::Create(src_buffer.scope()); if (scope.rank == runtime::StorageRank::kShared) { - ffi::Array region = r->source->region; + Region region = r->source->region; ffi::Array indices; for (int i = 0; i < static_cast(region.size()); i++) { Var var("region" + std::to_string(i)); diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 9bc92a9e1095..5cd9edca79dc 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -18,9 +18,9 @@ */ #include #include +#include #include #include -#include #include namespace tvm { @@ -83,7 +83,7 @@ StmtBlockDoc::StmtBlockDoc(ffi::Array stmts) { this->data_ = std::move(n); } -LiteralDoc::LiteralDoc(ffi::Any value, const ffi::Optional& object_path) { +LiteralDoc::LiteralDoc(ffi::Any value, const ffi::Optional& object_path) { ffi::ObjectPtr n = ffi::make_object(); n->value = value; if (object_path.defined()) { @@ -273,7 +273,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "script.printer.DocSetSourcePaths", - [](Doc doc, ffi::Array source_paths) { doc->source_paths = source_paths; }); + [](Doc doc, ffi::Array source_paths) { doc->source_paths = source_paths; }); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc index 30754a66ee51..ad81297f97be 100644 --- a/src/script/printer/doc_printer/base_doc_printer.cc +++ b/src/script/printer/doc_printer/base_doc_printer.cc @@ -264,7 +264,7 @@ DocPrinter::DocPrinter(const PrinterConfig& options) : options_(options) { void DocPrinter::Append(const Doc& doc) { Append(doc, PrinterConfig()); } void DocPrinter::Append(const Doc& doc, const PrinterConfig& cfg) { - for (const ffi::reflection::AccessPath& p : cfg->path_to_underline) { + for (const AccessPath& p : cfg->path_to_underline) { path_to_underline_.push_back(p); current_max_path_depth_.push_back(0); current_underline_candidates_.push_back(std::vector()); @@ -348,15 +348,15 @@ void DocPrinter::PrintDoc(const Doc& doc) { } size_t end_pos = output_.tellp(); - for (const ffi::reflection::AccessPath& path : doc->source_paths) { + for (const AccessPath& path : doc->source_paths) { MarkSpan({start_pos, end_pos}, path); } } -void DocPrinter::MarkSpan(const ByteSpan& span, const ffi::reflection::AccessPath& path) { +void DocPrinter::MarkSpan(const ByteSpan& span, const AccessPath& path) { int n = path_to_underline_.size(); for (int i = 0; i < n; ++i) { - ffi::reflection::AccessPath p = path_to_underline_[i]; + AccessPath p = path_to_underline_[i]; if (path->depth >= current_max_path_depth_[i] && path->IsPrefixOf(p)) { if (path->depth > current_max_path_depth_[i]) { current_max_path_depth_[i] = path->depth; diff --git a/src/script/printer/doc_printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h index cbad586d558e..6708ce156b20 100644 --- a/src/script/printer/doc_printer/base_doc_printer.h +++ b/src/script/printer/doc_printer/base_doc_printer.h @@ -255,7 +255,7 @@ class DocPrinter { std::vector underlines_exempted_; private: - void MarkSpan(const ByteSpan& span, const ffi::reflection::AccessPath& path); + void MarkSpan(const ByteSpan& span, const AccessPath& path); /*! \brief Options to customize certain aspects of the output */ PrinterConfig options_; @@ -267,7 +267,7 @@ class DocPrinter { std::vector line_starts_; /*! \brief Path of the object that we would like to underline */ - ffi::Array path_to_underline_; + ffi::Array path_to_underline_; /*! * \brief Candidate spans to be underlined, until we find a better match. diff --git a/src/script/printer/ir/distributed.cc b/src/script/printer/ir/distributed.cc index 60c0e3ceaf7e..5abc316154e0 100644 --- a/src/script/printer/ir/distributed.cc +++ b/src/script/printer/ir/distributed.cc @@ -24,7 +24,7 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](ffi::Shape n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](ffi::Shape n, AccessPath n_p, IRDocsifier d) -> Doc { int s = n.size(); ffi::Array results; results.reserve(s); diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 4029863aeeaa..a9b998d03eb2 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -56,7 +56,7 @@ struct SortableFunction { }; TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](IRModule mod, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](IRModule mod, AccessPath p, IRDocsifier d) -> Doc { std::vector functions; for (const auto& kv : mod->functions) { functions.push_back(SortableFunction(kv)); @@ -113,22 +113,22 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](DictAttrs attrs, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](DictAttrs attrs, AccessPath p, IRDocsifier d) -> Doc { return d->AsDoc(attrs->dict, p->Attr("dict")); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](GlobalVar gv, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](GlobalVar gv, AccessPath p, IRDocsifier d) -> Doc { return IR(d, "GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](GlobalInfo ginfo, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](GlobalInfo ginfo, AccessPath p, IRDocsifier d) -> Doc { return IR(d, "dummy_global_info")->Call({}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](VDevice vdev, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](VDevice vdev, AccessPath p, IRDocsifier d) -> Doc { d->AddGlobalInfo("vdevice", vdev); ffi::Map config = vdev->target->ToConfig(); return IR(d, "vdevice") @@ -138,12 +138,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](Op op, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](Op op, AccessPath p, IRDocsifier d) -> Doc { return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](FuncType func_type, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](FuncType func_type, AccessPath p, IRDocsifier d) -> Doc { return IR(d, "FuncType") ->Call({ d->AsDoc(func_type->arg_types, p->Attr("arg_types")), @@ -152,7 +152,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("ir", [](Range range, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("ir", [](Range range, AccessPath p, IRDocsifier d) -> Doc { return IR(d, "Range") ->Call({ d->AsDoc(range->min, p->Attr("min")), diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc index 64eb69bf5668..f33170577154 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -24,7 +24,7 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch>( // - "", [](ffi::Array array, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](ffi::Array array, AccessPath p, IRDocsifier d) -> Doc { int n = array.size(); ffi::Array results; results.reserve(n); @@ -36,7 +36,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch>( // - "", [](ffi::Map dict, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](ffi::Map dict, AccessPath p, IRDocsifier d) -> Doc { using POO = std::pair; std::vector items{dict.begin(), dict.end()}; bool is_str_map = true; diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 76631c169c24..dd5762973b73 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -16,10 +16,10 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include -#include #include #include @@ -209,7 +209,7 @@ IRDocsifier::FType& IRDocsifier::vtable() { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_fallback([](ffi::ObjectRef obj, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_fallback([](ffi::ObjectRef obj, AccessPath p, IRDocsifier d) -> Doc { return d->AddMetadata(obj); }); diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 41bed45a0552..eed29f102dfc 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -25,9 +25,9 @@ #include #include #include +#include #include #include -#include #include #include @@ -60,7 +60,7 @@ inline std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj) { inline std::string Docsify(const ffi::ObjectRef& obj, const IRDocsifier& d, const Frame& f, const PrinterConfig& cfg) { - Doc doc = d->AsDoc(obj, ffi::reflection::AccessPath::Root()); + Doc doc = d->AsDoc(obj, AccessPath::Root()); bool move_source_paths = false; if (const auto* expr_doc = doc.as()) { if (!cfg->verbose_expr) { diff --git a/src/target/cuda/intrin_rule_cuda.cc b/src/target/cuda/intrin_rule_cuda.cc index fc2d78a30710..d38db9fe8372 100644 --- a/src/target/cuda/intrin_rule_cuda.cc +++ b/src/target/cuda/intrin_rule_cuda.cc @@ -247,7 +247,7 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_sync") .add_argument("var", "Expr", "The variable to sync.") .add_argument("lane", "Expr", "The source thread id.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") - .set_attr("TGlobalSymbol", "__shfl_sync") + .set_attr("TGlobalSymbol", "__shfl_sync") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); @@ -257,7 +257,7 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_up_sync") .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be added.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") - .set_attr("TGlobalSymbol", "__shfl_up_sync") + .set_attr("TGlobalSymbol", "__shfl_up_sync") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); @@ -267,13 +267,13 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_down_sync") .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") - .set_attr("TGlobalSymbol", "__shfl_down_sync") + .set_attr("TGlobalSymbol", "__shfl_down_sync") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tirx.cuda.__activemask") .set_num_inputs(0) - .set_attr("TGlobalSymbol", "__activemask") + .set_attr("TGlobalSymbol", "__activemask") .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("cuda.need_warp_shuffle", true); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e6c3176867d5..b57a1a446bcf 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -563,7 +563,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::DISubprogram* di_subprogram_{nullptr}; // Cache potential common path ops to slightly improve lookup time. // global symbol table. - OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); + OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); const Op& builtin_call_extern_ = builtin::call_extern(); const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); diff --git a/src/target/metal/intrin_rule_metal.cc b/src/target/metal/intrin_rule_metal.cc index 94f4c0fbe308..cea19519ca7f 100644 --- a/src/target/metal/intrin_rule_metal.cc +++ b/src/target/metal/intrin_rule_metal.cc @@ -143,21 +143,21 @@ TVM_REGISTER_OP("tirx.metal.simd_shuffle") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("lane", "Expr", "The source thread id.") - .set_attr("TGlobalSymbol", "simd_shuffle") + .set_attr("TGlobalSymbol", "simd_shuffle") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.metal.simd_shuffle_up") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be added.") - .set_attr("TGlobalSymbol", "simd_shuffle_up") + .set_attr("TGlobalSymbol", "simd_shuffle_up") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.metal.simd_shuffle_down") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") - .set_attr("TGlobalSymbol", "simd_shuffle_down") + .set_attr("TGlobalSymbol", "simd_shuffle_down") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace intrin diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 0914abc79dff..29c5e420997e 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -316,7 +316,7 @@ class CodeGenC : public ExprFunctor, /*! \brief the data type of allocated buffers */ std::unordered_map handle_data_type_; /*! \brief Record of ops that have pre-defined global symbol. */ - OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); + OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); // cache commonly used ops const Op& builtin_call_extern_ = builtin::call_extern(); const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); diff --git a/src/target/virtual_device.cc b/src/target/virtual_device.cc index 6c83acbe4e2f..c7357d7f14f7 100644 --- a/src/target/virtual_device.cc +++ b/src/target/virtual_device.cc @@ -68,7 +68,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } VirtualDevice::VirtualDevice(int device_type_int, int virtual_device_id, Target target, - ffi::String memory_scope) { + MemoryScope memory_scope) { TVM_FFI_ICHECK(!target.defined() || device_type_int == target->GetTargetDeviceType()) << "target " << target->str() << " has device type " << target->GetTargetDeviceType() << " but virtual device has device type " << device_type_int; @@ -118,7 +118,7 @@ ffi::Optional VirtualDevice::Join(const VirtualDevice& lhs, } else { joined_target = rhs->target; } - ffi::String joined_memory_scope; + MemoryScope joined_memory_scope; if (!lhs->memory_scope.empty()) { joined_memory_scope = lhs->memory_scope; if (!rhs->memory_scope.empty() && lhs->memory_scope != rhs->memory_scope) { @@ -158,7 +158,7 @@ VirtualDevice VirtualDevice::Default(const VirtualDevice& lhs, const VirtualDevi } // else: leave as null } - ffi::String defaulted_memory_scope; + MemoryScope defaulted_memory_scope; if (!lhs->memory_scope.empty()) { defaulted_memory_scope = lhs->memory_scope; } else { @@ -169,7 +169,7 @@ VirtualDevice VirtualDevice::Default(const VirtualDevice& lhs, const VirtualDevi } VirtualDevice VirtualDeviceCache::Make(int device_type, int virtual_device_id, Target target, - ffi::String memory_scope) { + MemoryScope memory_scope) { VirtualDevice prototype(device_type, virtual_device_id, std::move(target), std::move(memory_scope)); if (prototype->IsFullyUnconstrained()) { diff --git a/src/target/webgpu/intrin_rule_webgpu.cc b/src/target/webgpu/intrin_rule_webgpu.cc index ce958a466d6c..bc48395468d3 100644 --- a/src/target/webgpu/intrin_rule_webgpu.cc +++ b/src/target/webgpu/intrin_rule_webgpu.cc @@ -163,21 +163,21 @@ TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("lane", "Expr", "The source thread id.") - .set_attr("TGlobalSymbol", "subgroupShuffle") + .set_attr("TGlobalSymbol", "subgroupShuffle") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_up") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be added.") - .set_attr("TGlobalSymbol", "subgroupShuffleUp") + .set_attr("TGlobalSymbol", "subgroupShuffleUp") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_down") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") - .set_attr("TGlobalSymbol", "subgroupShuffleDown") + .set_attr("TGlobalSymbol", "subgroupShuffleDown") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace intrin diff --git a/src/tirx/op/builtin.cc b/src/tirx/op/builtin.cc index 3ba16d9f9cf3..4355583d796b 100644 --- a/src/tirx/op/builtin.cc +++ b/src/tirx/op/builtin.cc @@ -211,11 +211,11 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array) // When num_inputs are not set, the function is assumed to be variable length. TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", ffi::String("call_packed"), /*plevel=*/20); + .set_attr("TScriptPrinterName", ffi::String("call_packed"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", ffi::String("call_cpacked"), /*plevel=*/20); + .set_attr("TScriptPrinterName", ffi::String("call_cpacked"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -226,12 +226,12 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_invariant) TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", ffi::String("call_packed_lowered"), + .set_attr("TScriptPrinterName", ffi::String("call_packed_lowered"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked_lowered) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", ffi::String("call_cpacked_lowered"), + .set_attr("TScriptPrinterName", ffi::String("call_cpacked_lowered"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered) @@ -410,7 +410,7 @@ TIR_DEFINE_BUILTIN_FUNC(anylist_getitem) TIR_DEFINE_BUILTIN_FUNC(anylist_resetitem) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TGlobalSymbol", "TVMBackendAnyListResetItem"); + .set_attr("TGlobalSymbol", "TVMBackendAnyListResetItem"); TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc index fa3953664bed..91539c9e7c28 100644 --- a/src/tirx/op/op.cc +++ b/src/tirx/op/op.cc @@ -25,11 +25,11 @@ #include #include +#include #include #include #include #include -#include #include // Centralized header for constant folders. @@ -1146,12 +1146,12 @@ TVM_TIR_REGISTER_PURE_BINARY_OP("ldexp"); TVM_TIR_REGISTER_OP("TVMBackendAllocWorkspace") .set_num_inputs(5) - .set_attr("TGlobalSymbol", "TVMBackendAllocWorkspace") + .set_attr("TGlobalSymbol", "TVMBackendAllocWorkspace") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") .set_num_inputs(3) - .set_attr("TGlobalSymbol", "TVMBackendFreeWorkspace") + .set_attr("TGlobalSymbol", "TVMBackendFreeWorkspace") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); // expose basic functions to node namespace diff --git a/src/tirx/op/runtime.cc b/src/tirx/op/runtime.cc index 148c2b9c132e..e013b21d6676 100644 --- a/src/tirx/op/runtime.cc +++ b/src/tirx/op/runtime.cc @@ -29,12 +29,12 @@ namespace tirx { TVM_REGISTER_OP("tirx.TVMBackendAnyListSetPackedArg") .set_num_inputs(5) - .set_attr("TGlobalSymbol", "TVMBackendAnyListSetPackedArg") + .set_attr("TGlobalSymbol", "TVMBackendAnyListSetPackedArg") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.TVMBackendAnyListMoveFromPackedReturn") .set_num_inputs(3) - .set_attr("TGlobalSymbol", "TVMBackendAnyListMoveFromPackedReturn") + .set_attr("TGlobalSymbol", "TVMBackendAnyListMoveFromPackedReturn") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace tirx diff --git a/src/tirx/script/printer/block.cc b/src/tirx/script/printer/block.cc index 3f7c86c49ff6..6c86d68ff5f4 100644 --- a/src/tirx/script/printer/block.cc +++ b/src/tirx/script/printer/block.cc @@ -22,14 +22,14 @@ namespace tvm { namespace script { namespace printer { -Doc PrintBlock(IRDocsifier d, tirx::SBlock block, ffi::reflection::AccessPath block_p, // +Doc PrintBlock(IRDocsifier d, tirx::SBlock block, AccessPath block_p, // ffi::Optional opt_realize, - ffi::Optional opt_realize_p) { + ffi::Optional opt_realize_p) { With frame(d, block); TVM_FFI_ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined()); const tirx::SBlockRealizeNode* realize = opt_realize.defined() ? opt_realize.value().get() : nullptr; - ffi::reflection::AccessPath realize_p = *opt_realize_p; + AccessPath realize_p = *opt_realize_p; // Step 1. Handle block var and block bindings // Step 1.1. Obtain all loop var defined along path std::unordered_map loop_vars; @@ -69,7 +69,7 @@ Doc PrintBlock(IRDocsifier d, tirx::SBlock block, ffi::reflection::AccessPath bl auto print_single_iter_var = [&](int i) { tirx::IterVar iter_var = block->iter_vars[i]; - ffi::reflection::AccessPath iter_var_p = block_p->Attr("iter_var")->ArrayItem(i); + AccessPath iter_var_p = block_p->Attr("iter_var")->ArrayItem(i); ExprDoc rhs = TIR(d, "axis"); if (iter_var->iter_type == tirx::IterVarType::kDataPar) { rhs = rhs->Attr("spatial"); @@ -120,10 +120,10 @@ Doc PrintBlock(IRDocsifier d, tirx::SBlock block, ffi::reflection::AccessPath bl lhs.reserve(m); loop_var_doc.reserve(m); std::string binding_type = ""; - ffi::Array binding_paths; + ffi::Array binding_paths; for (int i : remap_vars_indices) { tirx::IterVar iter_var = block->iter_vars[i]; - ffi::reflection::AccessPath iter_var_p = block_p->Attr("iter_vars")->ArrayItem(i); + AccessPath iter_var_p = block_p->Attr("iter_vars")->ArrayItem(i); lhs.push_back(DefineVar(iter_var->var, *frame, d)); loop_var_doc.push_back(d->AsDoc(realize->iter_values[i], realize_p->Attr("iter_values")->ArrayItem(i))); @@ -180,7 +180,7 @@ Doc PrintBlock(IRDocsifier d, tirx::SBlock block, ffi::reflection::AccessPath bl // Step 5. Handle `alloc_buffer` for (int i = 0, n = block->alloc_buffers.size(); i < n; ++i) { tirx::Buffer buffer = block->alloc_buffers[i]; - ffi::reflection::AccessPath buffer_p = block_p->Attr("alloc_buffers")->ArrayItem(i); + AccessPath buffer_p = block_p->Attr("alloc_buffers")->ArrayItem(i); IdDoc lhs = DefineBuffer(buffer, *frame, d); ExprDoc rhs = BufferDecl(buffer, "sblock_alloc_buffer", {}, buffer_p, *frame, d, BufferVarDefinition::DataPointer); @@ -189,7 +189,7 @@ Doc PrintBlock(IRDocsifier d, tirx::SBlock block, ffi::reflection::AccessPath bl // Step 6. Handle `match_buffer` for (int i = 0, n = block->match_buffers.size(); i < n; ++i) { tirx::MatchBufferRegion buffer_region = block->match_buffers[i]; - ffi::reflection::AccessPath buffer_region_p = block_p->Attr("match_buffers")->ArrayItem(i); + AccessPath buffer_region_p = block_p->Attr("match_buffers")->ArrayItem(i); StmtDoc doc = d->AsDoc(buffer_region, buffer_region_p); (*frame)->stmts.push_back(doc); } @@ -218,7 +218,7 @@ Doc PrintBlock(IRDocsifier d, tirx::SBlock block, ffi::reflection::AccessPath bl TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tirx::SBlockRealize realize, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::SBlockRealize realize, AccessPath p, IRDocsifier d) -> Doc { Doc doc = PrintBlock(d, realize->block, p->Attr("block"), realize, p); // since we do not have d->AsDoc for realize->block, // we should add possible doc decoration manually. @@ -227,7 +227,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::SBlock block, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::SBlock block, AccessPath p, IRDocsifier d) -> Doc { return PrintBlock(d, block, p, std::nullopt, std::nullopt); }); diff --git a/src/tirx/script/printer/buffer.cc b/src/tirx/script/printer/buffer.cc index 53149cdd7041..eb34153557ed 100644 --- a/src/tirx/script/printer/buffer.cc +++ b/src/tirx/script/printer/buffer.cc @@ -24,7 +24,7 @@ namespace tvm { namespace script { namespace printer { -ffi::Map BufferAttrs(tirx::Buffer buffer, const ffi::reflection::AccessPath& buffer_p, +ffi::Map BufferAttrs(tirx::Buffer buffer, const AccessPath& buffer_p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions) { using tvm::tirx::Var; @@ -53,14 +53,14 @@ ffi::Map BufferAttrs(tirx::Buffer buffer, const ffi::refle auto is_new_var = [&](const PrimExpr& e) { return e->IsInstance() && !d->IsVarDefined(e); }; - auto add_out_of_line_var_def = [&](const Var& var, const ffi::reflection::AccessPath& var_p) { + auto add_out_of_line_var_def = [&](const Var& var, const AccessPath& var_p) { TVM_FFI_ICHECK(!d->IsVarDefined(var)); ExprDoc lhs = DefineVar(var, frame, d); lhs->source_paths.push_back(var_p); var_def_lhs.push_back(lhs); var_def_rhs.push_back(PrintVarCreation(var, var_p, d)); }; - auto try_inline_def = [&](const PrimExpr& e, const ffi::reflection::AccessPath& e_p, + auto try_inline_def = [&](const PrimExpr& e, const AccessPath& e_p, std::function inline_f) { TVM_FFI_ICHECK(is_new_var(e)); Var var = Downcast(e); @@ -75,13 +75,13 @@ ffi::Map BufferAttrs(tirx::Buffer buffer, const ffi::refle // Step 1. Handle `buffer.shape` { const ffi::Array& shape = buffer->shape; - ffi::reflection::AccessPath shape_p = buffer_p->Attr("shape"); + AccessPath shape_p = buffer_p->Attr("shape"); int n = shape.size(); ffi::Array results; results.reserve(n); for (int i = 0; i < n; ++i) { PrimExpr e = shape[i]; - ffi::reflection::AccessPath e_p = shape_p->ArrayItem(i); + AccessPath e_p = shape_p->ArrayItem(i); if (is_new_var(e)) { add_out_of_line_var_def(Downcast(e), e_p); } @@ -110,13 +110,13 @@ ffi::Map BufferAttrs(tirx::Buffer buffer, const ffi::refle // Step 4. Handle `buffer.strides` if (!buffer->strides.empty()) { const ffi::Array& strides = buffer->strides; - ffi::reflection::AccessPath strides_p = buffer_p->Attr("strides"); + AccessPath strides_p = buffer_p->Attr("strides"); int n = strides.size(); ffi::Array results; results.reserve(n); for (int i = 0; i < n; ++i) { PrimExpr e = strides[i]; - ffi::reflection::AccessPath e_p = strides_p->ArrayItem(i); + AccessPath e_p = strides_p->ArrayItem(i); if (is_new_var(e)) { if (try_inline_def(e, e_p, [=]() { return d->AsDoc(buffer, buffer_p) @@ -203,14 +203,14 @@ ExprDoc BufferCall(const ExprDoc& prefix, const ffi::Map& } ExprDoc BufferDecl(const tirx::Buffer& buffer, const ffi::String& method, - const ffi::Array& args, const ffi::reflection::AccessPath& p, const Frame& frame, + const ffi::Array& args, const AccessPath& p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions) { return BufferCall(/*prefix=*/TIR(d, method), /*attrs=*/BufferAttrs(buffer, p, frame, d, var_definitions), /*args=*/args); } -ExprDoc BufferAttn(const tirx::Buffer& buffer, const ffi::reflection::AccessPath& p, const Frame& frame, +ExprDoc BufferAttn(const tirx::Buffer& buffer, const AccessPath& p, const Frame& frame, const IRDocsifier& d) { ffi::Map attrs = BufferAttrs(buffer, p, frame, d, BufferVarDefinition::DataPointer); @@ -220,7 +220,7 @@ ExprDoc BufferAttn(const tirx::Buffer& buffer, const ffi::reflection::AccessPath return TIR(d, "Buffer")->Call({shape, dtype}, {}, {}); } -ffi::Array BufferIndices(const ffi::Array& indices, const ffi::reflection::AccessPath& p, +ffi::Array BufferIndices(const ffi::Array& indices, const AccessPath& p, const IRDocsifier& d) { int n = indices.size(); ffi::Array indices_doc; @@ -228,8 +228,8 @@ ffi::Array BufferIndices(const ffi::Array& indices, const ffi::re for (int i = 0; i < n; ++i) { if (const auto* ramp = indices[i].as()) { if (const auto* stride = ramp->stride.as()) { - ffi::reflection::AccessPath ramp_p = p->Attr("indices")->ArrayItem(i); - ffi::reflection::AccessPath stride_p = ramp_p->Attr("stride"); + AccessPath ramp_p = p->Attr("indices")->ArrayItem(i); + AccessPath stride_p = ramp_p->Attr("stride"); ExprDoc start = d->AsDoc(ramp->base, // ramp_p->Attr("base")); ExprDoc stop = d->AsDoc(ramp->base + ramp->lanes * ramp->stride, // @@ -247,14 +247,14 @@ ffi::Array BufferIndices(const ffi::Array& indices, const ffi::re return indices_doc; } -ffi::Array BufferSlices(const ffi::Array& region, const ffi::reflection::AccessPath& p, +ffi::Array BufferSlices(const ffi::Array& region, const AccessPath& p, const IRDocsifier& d) { int n = region.size(); ffi::Array indices; indices.reserve(n); for (int i = 0; i < n; ++i) { Range range = region[i]; - ffi::reflection::AccessPath range_p = p->ArrayItem(i); + AccessPath range_p = p->ArrayItem(i); ExprDoc min = d->AsDoc(range->min, range_p->Attr("min")); if (tirx::is_one(range->extent)) { indices.push_back(min); @@ -268,14 +268,14 @@ ffi::Array BufferSlices(const ffi::Array& region, const ffi::reflect TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tirx::BufferRegion buffer_region, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::BufferRegion buffer_region, AccessPath p, IRDocsifier d) -> Doc { ExprDoc prefix = d->AsDoc(buffer_region->buffer, p->Attr("buffer")); return prefix[BufferSlices(buffer_region->region, p->Attr("region"), d)]; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::BufferStore store, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::BufferStore store, AccessPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(store->buffer, p->Attr("buffer")); ExprDoc value = d->AsDoc(store->value, p->Attr("value")); @@ -294,7 +294,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::BufferLoad load, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::BufferLoad load, AccessPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(load->buffer, p->Attr("buffer")); // Use .vload(...) syntax when there is a predicate @@ -308,7 +308,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tirx::Buffer buffer, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Buffer buffer, AccessPath p, IRDocsifier d) -> Doc { if (!d->IsVarDefined(buffer)) { if (ffi::Optional opt_f = FindLowestVarDef(buffer, d)) { ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d); @@ -326,7 +326,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tirx::MatchBufferRegion stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::MatchBufferRegion stmt, AccessPath p, IRDocsifier d) -> Doc { Frame frame = d->frames.back(); ExprDoc lhs = DefineBuffer(stmt->buffer, frame, d); ExprDoc src_buffer = d->AsDoc(stmt->source, p->Attr("source")); @@ -337,7 +337,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::ProducerLoad load, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::ProducerLoad load, AccessPath p, IRDocsifier d) -> Doc { ExprDoc prefix = IdDoc(load->producer->GetNameHint()); return prefix[BufferIndices(load->indices, p->Attr("indices"), d)]; }); diff --git a/src/tirx/script/printer/expr.cc b/src/tirx/script/printer/expr.cc index 6149c208d11f..d9902eb3aab0 100644 --- a/src/tirx/script/printer/expr.cc +++ b/src/tirx/script/printer/expr.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -#include #include +#include #include "./utils.h" @@ -25,9 +25,9 @@ namespace tvm { namespace script { namespace printer { -ExprDoc PrintVarCreation(const tirx::Var& var, const ffi::reflection::AccessPath& var_p, const IRDocsifier& d) { +ExprDoc PrintVarCreation(const tirx::Var& var, const AccessPath& var_p, const IRDocsifier& d) { Type type = var->type_annotation; - ffi::reflection::AccessPath type_p = var_p->Attr("type_annotation"); + AccessPath type_p = var_p->Attr("type_annotation"); ExprDoc rhs{ffi::UnsafeInit()}; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -65,7 +65,7 @@ ExprDoc PrintVarCreation(const tirx::Var& var, const ffi::reflection::AccessPath return rhs; } -Doc PrintVar(const tirx::Var& var, const ffi::reflection::AccessPath& var_p, const IRDocsifier& d) { +Doc PrintVar(const tirx::Var& var, const AccessPath& var_p, const IRDocsifier& d) { if (!d->IsVarDefined(var)) { if (ffi::Optional opt_f = FindLowestVarDef(var, d)) { ExprDoc lhs = DefineVar(var, opt_f.value(), d); @@ -83,17 +83,17 @@ Doc PrintVar(const tirx::Var& var, const ffi::reflection::AccessPath& var_p, con } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tirx::Var var, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Var var, AccessPath p, IRDocsifier d) -> Doc { return PrintVar(var, p, d); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tirx::SizeVar var, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::SizeVar var, AccessPath p, IRDocsifier d) -> Doc { return PrintVar(var, p, d); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::IterVar var, ffi::reflection::AccessPath var_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::IterVar var, AccessPath var_p, IRDocsifier d) -> Doc { return TIR(d, "iter_var") ->Call({ d->AsDoc(var->var, var_p->Attr("var")), @@ -104,7 +104,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Not node, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Not node, AccessPath p, IRDocsifier d) -> Doc { ExprDoc a = d->AsDoc(node->a, p->Attr("a")); if (a->IsInstance()) { return TIR(d, "Not")->Call({a}); @@ -113,7 +113,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::StringImm s, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::StringImm s, AccessPath p, IRDocsifier d) -> Doc { if (HasMultipleLines(s->value)) { return d->AddMetadata(s); } else { @@ -122,14 +122,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Cast cast, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Cast cast, AccessPath p, IRDocsifier d) -> Doc { ExprDoc dtype = LiteralDoc::DataType(cast->dtype, p->Attr("dtype")); ExprDoc value = d->AsDoc(cast->value, p->Attr("value")); return TIR(d, "Cast")->Call({dtype, value}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Select select, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Select select, AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Select") ->Call({ d->AsDoc(select->condition, p->Attr("condition")), @@ -139,7 +139,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Ramp ramp, ffi::reflection::AccessPath ramp_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Ramp ramp, AccessPath ramp_p, IRDocsifier d) -> Doc { return TIR(d, "Ramp")->Call({ d->AsDoc(ramp->base, ramp_p->Attr("base")), d->AsDoc(ramp->stride, ramp_p->Attr("stride")), @@ -149,7 +149,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", - [](tirx::Broadcast bc, ffi::reflection::AccessPath bc_p, IRDocsifier d) -> Doc { + [](tirx::Broadcast bc, AccessPath bc_p, IRDocsifier d) -> Doc { return TIR(d, "Broadcast") ->Call({ d->AsDoc(bc->value, bc_p->Attr("value")), @@ -159,7 +159,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::Shuffle shuffle, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::Shuffle shuffle, AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Shuffle") ->Call({ d->AsDoc(shuffle->vectors, p->Attr("vectors")), @@ -169,7 +169,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::CommReducer r, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::CommReducer r, AccessPath p, IRDocsifier d) -> Doc { TVM_FFI_ICHECK_EQ(r->lhs.size(), r->rhs.size()); ffi::Optional lambda; { @@ -200,8 +200,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); LambdaDoc PrintIndexMap(const ffi::ObjectRef& map, const ffi::Array& vs, - const ffi::reflection::AccessPath& vs_p, const ffi::Array& es, - const ffi::reflection::AccessPath& es_p, const IRDocsifier& d) { + const AccessPath& vs_p, const ffi::Array& es, + const AccessPath& es_p, const IRDocsifier& d) { With f(d, map); ffi::Array vars; for (int i = 0, l = vs.size(); i < l; ++i) { @@ -216,7 +216,7 @@ LambdaDoc PrintIndexMap(const ffi::ObjectRef& map, const ffi::Array& TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::IndexMap m, ffi::reflection::AccessPath m_p, IRDocsifier d) -> Doc { + "", [](tirx::IndexMap m, AccessPath m_p, IRDocsifier d) -> Doc { LambdaDoc map = PrintIndexMap(m, m->initial_indices, m_p->Attr("initial_indices"), m->final_indices, m_p->Attr("final_indices"), d); if (m->inverse_index_map.defined()) { @@ -232,7 +232,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Let let, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Let let, AccessPath p, IRDocsifier d) -> Doc { DictDoc where({d->AsDoc(let->var, p->Attr("var"))}, {d->AsDoc(let->value, p->Attr("value"))}); return TIR(d, "Let")->Call({d->AsDoc(let->body, p->Attr("body"))}, // @@ -240,9 +240,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Call call, ffi::reflection::AccessPath call_p, IRDocsifier d) -> Doc { - static const OpAttrMap& op_names = - Op::GetAttrMap("TScriptPrinterName"); + .set_dispatch("", [](tirx::Call call, AccessPath call_p, IRDocsifier d) -> Doc { + static const OpAttrMap& op_names = + Op::GetAttrMap("TScriptPrinterName"); static const OpAttrMap dtype_locations = Op::GetAttrMap("TScriptDtypePrintLocation"); tirx::ScriptDtypePrintLocation dtype_print_location = tirx::ScriptDtypePrintLocation::kNone; @@ -305,7 +305,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Reduce r, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Reduce r, AccessPath p, IRDocsifier d) -> Doc { ExprDoc combiner = d->AsDoc(r->combiner, p->Attr("combiner")); ExprDoc source = d->AsDoc(r->source, p->Attr("source")); ExprDoc init = d->AsDoc(r->init, p->Attr("init")); @@ -321,7 +321,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) #define TVM_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString) \ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ .set_dispatch("", \ - [](tirx::NodeType node, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { \ + [](tirx::NodeType node, AccessPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ return TIR(d, OpString)->Call({a, b}); \ @@ -337,7 +337,7 @@ bool IsNumber(const ExprDoc& e) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Div node, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Div node, AccessPath p, IRDocsifier d) -> Doc { ExprDoc a = d->AsDoc(node->a, p->Attr("a")); ExprDoc b = d->AsDoc(node->b, p->Attr("b")); PrimExpr ret = tvm::div(node->a, node->b); @@ -354,7 +354,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) #define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc, OpString, OpKind) \ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ .set_dispatch( \ - "", [](tirx::NodeType node, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { \ + "", [](tirx::NodeType node, AccessPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ PrimExpr ret = tvm::NodeFunc(node->a, node->b); \ diff --git a/src/tirx/script/printer/for_loop.cc b/src/tirx/script/printer/for_loop.cc index d0ea417a3635..9897dd2189b9 100644 --- a/src/tirx/script/printer/for_loop.cc +++ b/src/tirx/script/printer/for_loop.cc @@ -23,7 +23,7 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::For loop, ffi::reflection::AccessPath loop_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::For loop, AccessPath loop_p, IRDocsifier d) -> Doc { // Step 1. Check syntactic sugar: `T.grid` std::vector grid; std::unordered_set grid_loop_vars; diff --git a/src/tirx/script/printer/function.cc b/src/tirx/script/printer/function.cc index 15937d721cb1..a743539c5361 100644 --- a/src/tirx/script/printer/function.cc +++ b/src/tirx/script/printer/function.cc @@ -65,7 +65,7 @@ int CountVarOccurrence(const tirx::PrimFunc& f, const tirx::Var& v) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::PrimFunc func, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::PrimFunc func, AccessPath p, IRDocsifier d) -> Doc { With f(d, func); (*f)->AddDispatchToken(d, "tirx"); IdDoc func_name = IdDoc(FindFunctionName(d, func).value_or("main")); @@ -87,12 +87,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::unordered_set buffer_inlined; for (int i = 0; i < n_args; ++i) { tirx::Var var = func->params[i]; - ffi::reflection::AccessPath var_p = p->Attr("params")->ArrayItem(i); + AccessPath var_p = p->Attr("params")->ArrayItem(i); if (d->cfg->syntax_sugar && CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) { tirx::Buffer buffer = func->buffer_map[var]; if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) { - ffi::reflection::AccessPath buffer_p = p->Attr("buffer_map")->MapItem(var); + AccessPath buffer_p = p->Attr("buffer_map")->MapItem(var); IdDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc annotation = BufferAttn(buffer, buffer_p, *f, d); args.push_back(AssignDoc(lhs, std::nullopt, annotation)); @@ -135,7 +135,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) continue; } ExprDoc param_doc = args[i]->lhs; - ffi::reflection::AccessPath buffer_p = p->Attr("buffer_map")->MapItem(param); + AccessPath buffer_p = p->Attr("buffer_map")->MapItem(param); ExprDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc}, buffer_p, *f, d, BufferVarDefinition::MatchBuffer); @@ -165,12 +165,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }(); if (d->cfg->syntax_sugar && implicit_root_block) { tirx::SBlock root_block = implicit_root_block.value(); - ffi::reflection::AccessPath root_block_p = p->Attr("body")->Attr("block"); + AccessPath root_block_p = p->Attr("body")->Attr("block"); (*f)->stmts.push_back(CommentDoc("with T.sblock(\"root\"):")); // Handle root block `alloc_buffer` for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) { tirx::Buffer buffer = root_block->alloc_buffers[i]; - ffi::reflection::AccessPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayItem(i); + AccessPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayItem(i); IdDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc rhs = BufferDecl(buffer, "sblock_alloc_buffer", {}, buffer_p, *f, d, BufferVarDefinition::DataPointer); @@ -193,7 +193,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (!func->attrs.defined() || !func->attrs->dict.count(tvm::attr::kGlobalSymbol)) { ffi::Array pos_args; decorator = decorator->Call(pos_args, {"private"}, - {LiteralDoc::Boolean(true, ffi::Optional())}); + {LiteralDoc::Boolean(true, ffi::Optional())}); } return HeaderWrapper(d, FunctionDoc( @@ -208,7 +208,7 @@ TVM_REGISTER_SCRIPT_AS_REPR(tirx::PrimFuncNode, ReprPrintTIR); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "tirx", [](tvm::GlobalVar n, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // + "tirx", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } else { @@ -220,7 +220,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "tirx", [](tvm::IRModule mod, ffi::reflection::AccessPath n_p, IRDocsifier d) -> Doc { // + "tirx", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // ffi::Optional doc = d->GetVarDoc(mod); TVM_FFI_ICHECK(doc) << "Unable to print IRModule before definition in TIR."; return doc.value(); diff --git a/src/tirx/script/printer/ir.cc b/src/tirx/script/printer/ir.cc index 63af4bab0b4a..57bec5a56136 100644 --- a/src/tirx/script/printer/ir.cc +++ b/src/tirx/script/printer/ir.cc @@ -27,7 +27,7 @@ namespace printer { TVM_FFI_STATIC_INIT_BLOCK() { TIRFrameNode::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](IntImm imm, ffi::reflection::AccessPath imm_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](IntImm imm, AccessPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; if (dtype == d->cfg->int_dtype) { return LiteralDoc::Int(imm->value, imm_p->Attr("value")); @@ -40,7 +40,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](FloatImm imm, ffi::reflection::AccessPath imm_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](FloatImm imm, AccessPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; if (dtype == d->cfg->float_dtype) { return LiteralDoc::Float(imm->value, imm_p->Attr("value")); @@ -51,7 +51,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("tirx", [](Range range, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("tirx", [](Range range, AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Range") ->Call({ d->AsDoc(range->min, p->Attr("min")), @@ -60,12 +60,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](PrimType ty, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](PrimType ty, AccessPath p, IRDocsifier d) -> Doc { return TIR(d, DType2Str(ty->dtype)); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](PointerType ty, ffi::reflection::AccessPath ty_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](PointerType ty, AccessPath ty_p, IRDocsifier d) -> Doc { ExprDoc element_type{ffi::UnsafeInit()}; if (const auto* prim_type = ty->element_type.as()) { element_type = LiteralDoc::DataType(prim_type->dtype, // @@ -82,7 +82,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](TupleType ty, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](TupleType ty, AccessPath p, IRDocsifier d) -> Doc { if (ty->fields.empty()) { return LiteralDoc::None(p); } @@ -90,7 +90,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](Target target, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](Target target, AccessPath p, IRDocsifier d) -> Doc { ffi::Map config = target->ToConfig(); return TIR(d, "target")->Call({d->AsDoc(config, p)}); }); diff --git a/src/tirx/script/printer/stmt.cc b/src/tirx/script/printer/stmt.cc index 46dafdd62588..3c3ab21f9338 100644 --- a/src/tirx/script/printer/stmt.cc +++ b/src/tirx/script/printer/stmt.cc @@ -81,7 +81,7 @@ ffi::Optional FindReturnValue(const tirx::Stmt& node) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Evaluate eval, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Evaluate eval, AccessPath p, IRDocsifier d) -> Doc { if (d->cfg->syntax_sugar) { if (auto return_value = FindReturnValue(eval)) { ExprDoc value = @@ -98,7 +98,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::Bind stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Bind stmt, AccessPath p, IRDocsifier d) -> Doc { // Step 1. Type annotation ffi::Optional type_doc = d->AsDoc(stmt->var->type_annotation, // p->Attr("var")->Attr("type_annotation")); @@ -122,7 +122,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tirx::AssertStmt stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::AssertStmt stmt, AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); // Always emit the canonical tuple form: assert cond, ("Kind", ["part0", "part1", ...]) ffi::Array parts; @@ -135,7 +135,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::While stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::While stmt, AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); With f(d, stmt); AsDocBody(stmt->body, p->Attr("body"), f->get(), d); @@ -143,7 +143,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); namespace { -Doc DeclBufferDoc(tirx::DeclBuffer stmt, ffi::reflection::AccessPath p, IRDocsifier d, +Doc DeclBufferDoc(tirx::DeclBuffer stmt, AccessPath p, IRDocsifier d, BufferVarDefinition var_definitions) { ExprDoc rhs = BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), d->frames.back(), d, var_definitions); @@ -154,13 +154,13 @@ Doc DeclBufferDoc(tirx::DeclBuffer stmt, ffi::reflection::AccessPath p, IRDocsif TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::DeclBuffer stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::DeclBuffer stmt, AccessPath p, IRDocsifier d) -> Doc { return DeclBufferDoc(stmt, p, d, BufferVarDefinition::None); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::IfThenElse stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::IfThenElse stmt, AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); ffi::Array then_branch; ffi::Array else_branch; @@ -178,13 +178,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tirx::SeqStmt stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::SeqStmt stmt, AccessPath p, IRDocsifier d) -> Doc { With f(d, stmt); AsDocBody(stmt, p, f->get(), d); return StmtBlockDoc((*f)->stmts); }); -void InsertEnvThread(const tirx::IterVar& iter_var, const ffi::reflection::AccessPath& iter_var_p, +void InsertEnvThread(const tirx::IterVar& iter_var, const AccessPath& iter_var_p, const IRDocsifier& d) { Frame f = FindLowestVarDef(iter_var->var, d).value(); DefineVar(iter_var->var, f, d); @@ -195,10 +195,10 @@ void InsertEnvThread(const tirx::IterVar& iter_var, const ffi::reflection::Acces f->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } -ExprDoc DocsifyLaunchThread(const tirx::AttrStmt& attr_stmt, const ffi::reflection::AccessPath& attr_stmt_p, +ExprDoc DocsifyLaunchThread(const tirx::AttrStmt& attr_stmt, const AccessPath& attr_stmt_p, ffi::Optional* define_var, const IRDocsifier& d) { tirx::IterVar iter_var = Downcast(attr_stmt->node); - ffi::reflection::AccessPath iter_var_p = attr_stmt_p->Attr("node"); + AccessPath iter_var_p = attr_stmt_p->Attr("node"); ExprDoc var_doc{ffi::UnsafeInit()}; if (d->IsVarDefined(iter_var->var)) { @@ -219,13 +219,13 @@ ExprDoc DocsifyLaunchThread(const tirx::AttrStmt& attr_stmt, const ffi::reflecti TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::AttrStmt stmt, ffi::reflection::AccessPath stmt_p, IRDocsifier d) -> Doc { + "", [](tirx::AttrStmt stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); ffi::Optional lhs = std::nullopt; ffi::Optional rhs = std::nullopt; ffi::Optional define_var = std::nullopt; tirx::Stmt body = stmt->body; - ffi::reflection::AccessPath body_p = stmt_p->Attr("body"); + AccessPath body_p = stmt_p->Attr("body"); if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") { if (stmt->node.as()) { rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d); @@ -252,9 +252,9 @@ TVM_REGISTER_SCRIPT_AS_REPR(tirx::AssertStmtNode, ReprPrintTIR); TVM_REGISTER_SCRIPT_AS_REPR(tirx::WhileNode, ReprPrintTIR); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tirx::AllocBuffer stmt, ffi::reflection::AccessPath p, IRDocsifier d) -> Doc { + "", [](tirx::AllocBuffer stmt, AccessPath p, IRDocsifier d) -> Doc { tirx::Buffer buffer = stmt->buffer; - ffi::reflection::AccessPath buffer_p = p->Attr("buffer"); + AccessPath buffer_p = p->Attr("buffer"); Frame frame = d->frames.back(); // Define buffer's data var inline as buffer.data if (!d->IsVarDefined(buffer->data)) { @@ -272,10 +272,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) int n = buffer->shape.size(); ffi::Array shape_docs; shape_docs.reserve(n); - ffi::reflection::AccessPath shape_p = buffer_p->Attr("shape"); + AccessPath shape_p = buffer_p->Attr("shape"); for (int i = 0; i < n; ++i) { PrimExpr e = buffer->shape[i]; - ffi::reflection::AccessPath e_p = shape_p->ArrayItem(i); + AccessPath e_p = shape_p->ArrayItem(i); if (!d->IsVarDefined(e) && e->IsInstance()) { ExprDoc lhs = DefineVar(Downcast(e), frame, d); lhs->source_paths.push_back(e_p); diff --git a/src/tirx/script/printer/utils.h b/src/tirx/script/printer/utils.h index 3207edf8bfe5..8dc6e703bccd 100644 --- a/src/tirx/script/printer/utils.h +++ b/src/tirx/script/printer/utils.h @@ -108,7 +108,7 @@ inline IdDoc DefineBuffer(const tirx::Buffer& buffer, const Frame& frame, const * \param f The frame * \param d The IRDocsifier */ -inline void AsDocBody(const tirx::Stmt& stmt, ffi::reflection::AccessPath p, TIRFrameNode* f, const IRDocsifier& d) { +inline void AsDocBody(const tirx::Stmt& stmt, AccessPath p, TIRFrameNode* f, const IRDocsifier& d) { if (const auto* seq_stmt = stmt.as()) { ffi::Array body = seq_stmt->seq; for (int i = 0, n = body.size(); i < n; ++i) { @@ -214,7 +214,7 @@ enum class BufferVarDefinition { * \return The ExprDoc corresponding to the buffer declaration */ ExprDoc BufferDecl(const tirx::Buffer& buffer, const ffi::String& method, - const ffi::Array& args, const ffi::reflection::AccessPath& p, const Frame& frame, + const ffi::Array& args, const AccessPath& p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions); /*! @@ -225,7 +225,7 @@ ExprDoc BufferDecl(const tirx::Buffer& buffer, const ffi::String& method, * \param d The IRDocsifier * \return The ExprDoc corresponding to the buffer declaration */ -ExprDoc BufferAttn(const tirx::Buffer& buffer, const ffi::reflection::AccessPath& p, const Frame& frame, +ExprDoc BufferAttn(const tirx::Buffer& buffer, const AccessPath& p, const Frame& frame, const IRDocsifier& d); /*! @@ -235,7 +235,7 @@ ExprDoc BufferAttn(const tirx::Buffer& buffer, const ffi::reflection::AccessPath * \param d The IRDocsifier * \return The ExprDoc corresponding to the Var creation */ -ExprDoc PrintVarCreation(const tirx::Var& var, const ffi::reflection::AccessPath& var_p, const IRDocsifier& d); +ExprDoc PrintVarCreation(const tirx::Var& var, const AccessPath& var_p, const IRDocsifier& d); /*! \brief A Var occurrence counter visitor */ class OccurrenceCounter : public tirx::StmtExprVisitor { diff --git a/src/tirx/transform/ir_utils.cc b/src/tirx/transform/ir_utils.cc index 96b0209415ba..9130bca9c091 100644 --- a/src/tirx/transform/ir_utils.cc +++ b/src/tirx/transform/ir_utils.cc @@ -682,13 +682,13 @@ ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, return result; } -ffi::Array ConvertRegion(const MatchBufferRegion& match_buffer, const ffi::Array& region) { +Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region) { const Buffer& target = match_buffer->buffer; const BufferRegion& source = match_buffer->source; TVM_FFI_ICHECK_EQ(region.size(), target->shape.size()); arith::Analyzer analyzer; - ffi::Array result; + Region result; result.reserve(source->region.size()); size_t offset = source->region.size() - region.size(); for (size_t i = 0; i < offset; ++i) { diff --git a/src/tirx/transform/ir_utils.h b/src/tirx/transform/ir_utils.h index 6427ae43a2e2..f77d73fbcff0 100644 --- a/src/tirx/transform/ir_utils.h +++ b/src/tirx/transform/ir_utils.h @@ -227,7 +227,7 @@ ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, * \param region The sub-region of the target buffer * \return The region of source buffer. */ -ffi::Array ConvertRegion(const MatchBufferRegion& match_buffer, const ffi::Array& region); +Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region); /*! * \brief Get stride aware buffer allocation shape from buffer.