Skip to content
Merged
213 changes: 88 additions & 125 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
* This module enables declaration of named attributes
* which support default value setup and bound checking.
*
* \sa BaseAttrsNode, AttrsWithDefaultValues
* \sa AttrsNode
*/
#ifndef TVM_IR_ATTRS_H_
#define TVM_IR_ATTRS_H_
Expand All @@ -43,59 +43,23 @@

namespace tvm {

/*!
* \brief Information about attribute fields in string representations.
*/
class AttrFieldInfoNode : public ffi::Object {
public:
/*! \brief name of the field */
ffi::String name;
/*! \brief type docstring information in str. */
ffi::String type_info;
/*! \brief detailed description of the type */
ffi::String description;

static void RegisterReflection() {
namespace rfl = ffi::reflection;
rfl::ObjectDef<AttrFieldInfoNode>()
.def_ro("name", &AttrFieldInfoNode::name)
.def_ro("type_info", &AttrFieldInfoNode::type_info)
.def_ro("description", &AttrFieldInfoNode::description);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;

TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.AttrFieldInfo", AttrFieldInfoNode, ffi::Object);
};

/*! \brief AttrFieldInfo */
class AttrFieldInfo : public ffi::ObjectRef {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrFieldInfo, ffi::ObjectRef, AttrFieldInfoNode);
};

/*!
* \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly,
* subclass AttrsNode instead.
* \sa AttrsNode
* \sa Attrs
*/
class BaseAttrsNode : public ffi::Object {
class AttrsNode : public ffi::Object {
public:
/*! \brief virtual destructor */
virtual ~BaseAttrsNode() {}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", BaseAttrsNode, ffi::Object);
TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", AttrsNode, ffi::Object);
};

/*!
* \brief Managed reference to BaseAttrsNode.
* \sa AttrsNode, BaseAttrsNode
* \brief Managed reference to AttrsNode.
* \sa AttrsNode
*/
class Attrs : public ffi::ObjectRef {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Attrs, ffi::ObjectRef, BaseAttrsNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Attrs, ffi::ObjectRef, AttrsNode);
};

/*!
Expand All @@ -104,7 +68,7 @@ class Attrs : public ffi::ObjectRef {
* its fields are directly accessible via object.field_name
* like other normal nodes.
*/
class DictAttrsNode : public BaseAttrsNode {
class DictAttrsNode : public AttrsNode {
public:
/*! \brief internal attrs map */
ffi::Map<ffi::String, ffi::Any> dict;
Expand All @@ -115,28 +79,70 @@ class DictAttrsNode : public BaseAttrsNode {
}

// type info
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode, BaseAttrsNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode, AttrsNode);
};

/*!
* \brief Managed reference to DictAttrsNode
* \sa DictAttrsNode.
*
* \note DictAttrs is NOTNULLABLE: every instance must hold a backing
* DictAttrsNode. The class enforces this end-to-end by:
* - the default constructor (no args) allocating an empty backing,
* - the copy/move ctors and assignments leaving the moved-from
* instance in a defined-but-empty state rather than null,
* - the FFI type traits rejecting None at deserialization boundaries
* (since `_type_is_nullable == false`), and
* - the FFI lambda for ``ir.IRModule`` explicitly normalizing a
* missing/None attrs argument to ``DictAttrs()`` before forwarding
* to the C++ constructor.
* Callers (including third-party code via templates like ``WithAttr``)
* can therefore rely on ``attrs->dict`` being safe to dereference
* without a ``.defined()`` guard.
*/
class DictAttrs : public Attrs {
public:
/*!
* \brief constructor with UnsafeInit
* \brief Construct a DictAttrs backed by DictAttrsNode.
*
* The no-argument form constructs an empty (but always defined) DictAttrs.
* \param dict The attributes.
*/
explicit DictAttrs(ffi::Map<ffi::String, Any> dict = {}) {
ffi::ObjectPtr<DictAttrsNode> n = ffi::make_object<DictAttrsNode>();
n->dict = std::move(dict);
data_ = std::move(n);
}

/*!
* \brief Move constructor that leaves the source in a defined-but-empty
* state rather than null, preserving the NOTNULLABLE invariant
* even after `std::move`.
*/
explicit DictAttrs(ffi::UnsafeInit tag) : Attrs(tag) {}
DictAttrs(DictAttrs&& other) noexcept : Attrs(ffi::UnsafeInit{}) {
data_ = std::move(other.data_);
other.data_ = ffi::make_object<DictAttrsNode>();
}

/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
* \brief Move assignment that leaves the source in a defined-but-empty
* state rather than null, preserving the NOTNULLABLE invariant
* even after `std::move`.
*/
TVM_DLL explicit DictAttrs(ffi::Map<ffi::String, Any> dict = {});
DictAttrs& operator=(DictAttrs&& other) noexcept {
if (this != &other) {
data_ = std::move(other.data_);
other.data_ = ffi::make_object<DictAttrsNode>();
}
return *this;
}

// Explicit copy ctor/assign defaults. Declaring the move members above
// would otherwise suppress the implicit copy members.
DictAttrs(const DictAttrs& other) = default;
DictAttrs& operator=(const DictAttrs& other) = default;

// Utils for accessing attributes
// This needs to be on DictAttrs, not DictAttrsNode because we return the default
// value if DictAttrsNode is not defined.
/*!
* \brief Get a function attribute.
*
Expand All @@ -160,8 +166,7 @@ class DictAttrs : public Attrs {
ffi::Optional<TObjectRef> GetAttr(
const std::string& attr_key,
ffi::Optional<TObjectRef> default_value = ffi::Optional<TObjectRef>(std::nullopt)) const {
if (!defined()) return default_value;
const DictAttrsNode* node = this->as<DictAttrsNode>();
const DictAttrsNode* node = get();
Comment thread
tqchen marked this conversation as resolved.
auto it = node->dict.find(attr_key);
if (it != node->dict.end()) {
return (*it).second.cast<TObjectRef>();
Expand Down Expand Up @@ -197,57 +202,19 @@ class DictAttrs : public Attrs {
return GetAttr<int64_t>(attr_key, 0).value_or(0) != 0;
}

explicit DictAttrs(::tvm::ffi::ObjectPtr<DictAttrsNode> n) : Attrs(n) {}
DictAttrs(const DictAttrs&) = default;
DictAttrs(DictAttrs&&) = default;
DictAttrs& operator=(const DictAttrs&) = default;
DictAttrs& operator=(DictAttrs&&) = default;
const DictAttrsNode* operator->() const { return static_cast<const DictAttrsNode*>(data_.get()); }
const DictAttrsNode* get() const { return operator->(); }
// Inline-expand TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE here, minus
// the default copy/move it normally injects (we define our own move members
// above so the moved-from instance stays defined-but-empty).
explicit DictAttrs(::tvm::ffi::UnsafeInit tag) : Attrs(tag) {}
using __PtrType =
std::conditional_t<DictAttrsNode::_type_mutable, DictAttrsNode*, const DictAttrsNode*>;
__PtrType operator->() const { return static_cast<__PtrType>(data_.get()); }
__PtrType get() const { return static_cast<__PtrType>(data_.get()); }
static constexpr bool _type_is_nullable = false;
using ContainerType = DictAttrsNode;
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};

/*!
* \brief Copy the DictAttrs, but overrides attributes with the
* entries from \p attrs.
*
* \param attrs The DictAttrs to update
*
* \param new_attrs Key/values attributes to add to \p attrs.
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithAttrs(DictAttrs attrs, ffi::Map<ffi::String, Any> new_attrs);

/*!
* \brief Copy the DictAttrs, but overrides a single attribute.
*
* \param attrs The DictAttrs to update
*
* \param key The update to insert or update.
*
* \param value The new value of the attribute
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithAttr(DictAttrs attrs, ffi::String key, Any value);

inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, Any value) {
return WithAttr(std::move(attrs), ffi::String(key), std::move(value));
}

/*!
* \brief Copy the DictAttrs, but without a specific attribute.
*
* \param attrs The DictAttrs to update
*
* \param key The key to remove
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key);

/*!
* \brief Copy the function or module, but overrides
* the attribute value key with the value.
Expand Down Expand Up @@ -280,7 +247,10 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, Any attr_value)
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value);
// node->attrs is NOTNULLABLE by contract, but defend against a caller
// that left a moved-from DictAttrs in place by re-initializing here.
if (!node->attrs.defined()) node->attrs = DictAttrs();
node->attrs.CopyOnWrite()->dict.Set(attr_key, std::move(attr_value));

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If node->attrs is null (which can happen if the function was deserialized from a format where attributes were omitted), calling CopyOnWrite() on it will cause a segmentation fault. We should defensively initialize node->attrs if it is not defined.

  if (!node->attrs.defined()) {
    node->attrs = DictAttrs();
  }
  node->attrs.CopyOnWrite()->dict.Set(attr_key, std::move(attr_value));

return input;
}

Expand All @@ -298,10 +268,15 @@ template <typename TFunc>
inline TFunc WithAttrs(TFunc input, ffi::Map<ffi::String, Any> attrs) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
if (attrs.empty()) return input;
TNode* node = input.CopyOnWrite();

node->attrs = WithAttrs(std::move(node->attrs), attrs);

// node->attrs is NOTNULLABLE by contract, but defend against a caller
// that left a moved-from DictAttrs in place by re-initializing here.
if (!node->attrs.defined()) node->attrs = DictAttrs();
auto* dict_node = node->attrs.CopyOnWrite();

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If node->attrs is null, calling CopyOnWrite() on it will cause a segmentation fault. We should defensively initialize node->attrs if it is not defined.

Suggested change
auto* dict_node = node->attrs.CopyOnWrite();
if (!node->attrs.defined()) {
node->attrs = DictAttrs();
}
auto* dict_node = node->attrs.CopyOnWrite();

for (const auto& [k, v] : attrs) {
dict_node->dict.Set(k, v);
}
return input;
}

Expand Down Expand Up @@ -335,29 +310,17 @@ template <typename TFunc>
inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");

TNode* node = input.CopyOnWrite();
node->attrs = WithoutAttr(std::move(node->attrs), attr_key);

// node->attrs is NOTNULLABLE by contract, but defend against a caller
// that left a moved-from DictAttrs in place; nothing to erase from an
// empty dict.
if (!node->attrs.defined()) {
node->attrs = DictAttrs();
return input;
}
node->attrs.CopyOnWrite()->dict.erase(attr_key);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If node->attrs is null, calling CopyOnWrite() on it will cause a segmentation fault. We should check if node->attrs is defined before attempting to erase from it.

  if (node->attrs.defined()) {
    node->attrs.CopyOnWrite()->dict.erase(attr_key);
  }

return input;
}

/*!
* \brief Create an object with all default values, using the reflection defaults.
* \tparam TObj the ObjectRef type to be created.
* \return An instance with all reflection-defined default values applied.
*/
template <typename TObj>
inline TObj AttrsWithDefaultValues() {
static_assert(std::is_base_of_v<ffi::ObjectRef, TObj>, "Can only create ObjectRef-derived types");
using ContainerType = typename TObj::ContainerType;
static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs");
AnyView packed_args[1];
packed_args[0] = ContainerType::RuntimeTypeIndex();
ffi::Any rv;
finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv);
return rv.cast<TObj>();
}

} // namespace tvm
#endif // TVM_IR_ATTRS_H_
41 changes: 38 additions & 3 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,41 @@ namespace tvm {
template <typename>
class OpAttrMap;

/*!
* \brief Information about an input field of an Op (name, type, description).
*
* Populated via OpRegEntry::add_argument and consumed both by
* internal sanity checks / error messages and by external tooling
* that wants to introspect an Op's argument schema.
*/
class ArgumentInfoNode : public ffi::Object {
public:
/*! \brief name of the field */
ffi::String name;
/*! \brief type docstring information in str. */
ffi::String type_info;
/*! \brief detailed description of the type */
ffi::String description;

static void RegisterReflection() {
namespace rfl = ffi::reflection;
rfl::ObjectDef<ArgumentInfoNode>()
.def_ro("name", &ArgumentInfoNode::name)
.def_ro("type_info", &ArgumentInfoNode::type_info)
.def_ro("description", &ArgumentInfoNode::description);
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;

TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.ArgumentInfo", ArgumentInfoNode, ffi::Object);
};

/*! \brief Managed reference to ArgumentInfoNode. */
class ArgumentInfo : public ffi::ObjectRef {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ArgumentInfo, ffi::ObjectRef, ArgumentInfoNode);
};

// TODO(tvm-team): migrate low-level intrinsics to use Op
/*!
* \brief Primitive Op(builtin intrinsics)
Expand All @@ -68,7 +103,7 @@ class OpNode : public RelaxExprNode {
*/
ffi::String description;
/* \brief Information of input arguments to the operator */
ffi::Array<AttrFieldInfo> arguments;
ffi::Array<ArgumentInfo> arguments;
/*!
* \brief The type key of the attribute field
* This can be empty, in which case it defaults to anything.
Expand Down Expand Up @@ -330,11 +365,11 @@ inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(*

inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
const std::string& description) {
auto n = ffi::make_object<AttrFieldInfoNode>();
auto n = ffi::make_object<ArgumentInfoNode>();
n->name = name;
n->type_info = type;
n->description = description;
get()->arguments.push_back(AttrFieldInfo(n));
get()->arguments.push_back(ArgumentInfo(n));
return *this;
}

Expand Down
Loading
Loading