Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions include/tvm/ffi/enum.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,18 @@ inline Enum EnumObj::Get(const String& name) {
static_assert(std::is_base_of_v<EnumObj, EnumClsObj>,
"EnumObj::Get<T> requires T to be a subclass of EnumObj");
const TVMFFITypeAttrColumn* column = GetEnumEntriesColumn();
int32_t type_index = EnumClsObj::RuntimeTypeIndex();
int32_t type_index = EnumClsObj::_GetOrAllocRuntimeTypeIndex();
if (column != nullptr) {
int32_t offset = type_index - column->begin_index;
if (offset >= 0 && offset < column->size) {
const TVMFFIAny* stored = &column->data[offset];
if (stored->type_index != kTVMFFINone) {
Dict<String, Enum> entries = AnyView::CopyFromTVMFFIAny(*stored).cast<Dict<String, Enum>>();
Dict<String, ObjectRef> entries =
AnyView::CopyFromTVMFFIAny(*stored).cast<Dict<String, ObjectRef>>();
auto it = entries.find(name);
if (it != entries.end()) {
return (*it).second;
return details::ObjectUnsafe::ObjectRefFromObjectPtr<Enum>(
details::ObjectUnsafe::ObjectPtrFromObjectRef<EnumObj>((*it).second));
}
}
}
Expand Down
13 changes: 7 additions & 6 deletions include/tvm/ffi/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,11 @@ class Object {
* \brief Get the runtime allocated type index of the type
* \note Getting this information may need dynamic calls into a global table.
*/
static int32_t RuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; }
TVM_FFI_INLINE static int32_t RuntimeTypeIndex() noexcept { return TypeIndex::kTVMFFIObject; }
/*!
* \brief Internal function to get or allocate a runtime index.
*/
static int32_t _GetOrAllocRuntimeTypeIndex() { // NOLINT(bugprone-reserved-identifier)
TVM_FFI_COLD_CODE static int32_t _GetOrAllocRuntimeTypeIndex() { // NOLINT(*)
return TypeIndex::kTVMFFIObject;
}

Expand Down Expand Up @@ -936,7 +936,7 @@ struct ObjectPtrEqual {
*/
#define TVM_FFI_DECLARE_OBJECT_INFO_STATIC(TypeKey, TypeName, ParentType) \
static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \
static int32_t _GetOrAllocRuntimeTypeIndex() { \
TVM_FFI_COLD_CODE static int32_t _GetOrAllocRuntimeTypeIndex() { \
static_assert(!ParentType::_type_final, "ParentType marked as final"); \
static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
Expand All @@ -948,7 +948,7 @@ struct ObjectPtrEqual {
TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \
return TypeName::_type_index; \
} \
static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \
TVM_FFI_INLINE static int32_t RuntimeTypeIndex() noexcept { return TypeName::_type_index; } \
static constexpr const char* _type_key = TypeKey

/*!
Expand All @@ -959,7 +959,7 @@ struct ObjectPtrEqual {
*/
#define TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(TypeName, ParentType) \
static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \
static int32_t _GetOrAllocRuntimeTypeIndex() { \
TVM_FFI_COLD_CODE static int32_t _GetOrAllocRuntimeTypeIndex() { \
static_assert(!ParentType::_type_final, "ParentType marked as final"); \
static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
Expand All @@ -971,7 +971,8 @@ struct ObjectPtrEqual {
TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \
return tindex; \
} \
static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); }
static inline const int32_t _type_index = TypeName::_GetOrAllocRuntimeTypeIndex(); \
TVM_FFI_INLINE static int32_t RuntimeTypeIndex() noexcept { return TypeName::_type_index; }

/*!
* \brief Helper macro to declare object information with dynamic type index.
Expand Down
9 changes: 5 additions & 4 deletions include/tvm/ffi/reflection/enum_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,16 @@ class EnumDef : public ReflectionDefBase {
* \param instance_name The instance's string name (e.g., ``"Add"``).
*/
explicit EnumDef(const char* instance_name)
: type_index_(EnumClsObj::RuntimeTypeIndex()), name_(instance_name) {
Dict<String, Enum> entries = EnsureEntriesDict();
: type_index_(EnumClsObj::_GetOrAllocRuntimeTypeIndex()), name_(instance_name) {
Dict<String, ObjectRef> entries = EnsureEntriesDict();
String name_str(name_);
if (entries.count(name_str) != 0) {
TVM_FFI_THROW(RuntimeError) << "Duplicate enum entry `" << name_ << "` for type `"
<< EnumClsObj::_type_key << "`";
}
ordinal_ = static_cast<int64_t>(entries.size());
ObjectPtr<EnumClsObj> obj = make_object<EnumClsObj>();
::tvm::ffi::details::ObjectUnsafe::GetHeader(obj.get())->type_index = type_index_;
obj->_value = ordinal_;
obj->_name = name_str;
instance_ = Enum(ObjectPtr<EnumObj>(std::move(obj)));
Expand Down Expand Up @@ -134,8 +135,8 @@ class EnumDef : public ReflectionDefBase {
int64_t ordinal() const { return ordinal_; }

private:
Dict<String, Enum> EnsureEntriesDict() {
return EnsureDict<Dict<String, Enum>>(type_attr::kEnumEntries);
Dict<String, ObjectRef> EnsureEntriesDict() {
return EnsureDict<Dict<String, ObjectRef>>(type_attr::kEnumEntries);
}

Dict<String, List<Any>> EnsureAttrsDict() {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ffi/reflection/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ class TypeAttrDef : public ReflectionDefBase {
*/
template <typename... ExtraArgs>
explicit TypeAttrDef(ExtraArgs&&... extra_args)
: type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {}
: type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) {}

/*!
* \brief Define a function-valued type attribute.
Expand Down
Loading