diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 4a00de802c61..b54a067e1c94 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -260,9 +260,10 @@ class GlobalVarNode : public RelayExprNode { */ class GlobalVar : public RelayExpr { public: - TVM_DLL explicit GlobalVar(String name_hint, Type type = {}); + TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {}); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode); }; // PrimExprs that are useful as runtime containers. diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index fe570806922f..4d25164f314d 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -39,6 +39,22 @@ #include "./type.h" namespace tvm { + +/*! + * \brief Returns the global_var with given properties. A null property denotes 'no change'. + * Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param global_var The global var to copy + * \param opt_name_hint The (optional) op name of the global var + * \param opt_type The (optional) type for the global var + * \param opt_virtual_device The (optional) virtual_device for the copied constant. If none, + * ret_constant->virtual_device = constant->virtual_device. + * \param opt_span The (optional) span for the copied global var. If none, + * ret_constant->span = constant->span. + */ +GlobalVar WithFields(GlobalVar global_var, Optional opt_name_hint = {}, + Optional opt_type = {}, Optional opt_virtual_device = {}, + Optional opt_span = {}); + namespace relay { using Expr = tvm::RelayExpr; @@ -97,8 +113,23 @@ class Constant : public Expr { TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode); }; +/*! + * \brief Returns the constant with given properties. A null property denotes 'no change'. + * Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param constant The constant to copy + * \param opt_data The (optional) data for the copied constant. If none, ret_constant->data = + * constant->data. + * \param opt_virtual_device The (optional) virtual_device for the copied constant. If none, + * ret_constant->virtual_device = constant->virtual_device. + * \param opt_span The (optional) span for the copied constant. If none, + * ret_constant->span = constant->span. + */ +Constant WithFields(Constant constant, Optional opt_data = {}, + Optional opt_virtual_device = {}, Optional opt_span = {}); + /*! \brief Tuple of multiple Exprs */ class Tuple; /*! \brief Tuple container */ diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 399873492f04..a3318bf94fc6 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -141,10 +141,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); -GlobalVar::GlobalVar(String name_hint, Type type) { +GlobalVar::GlobalVar(String name_hint, Type type, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); n->checked_type_ = std::move(type); + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index fc76577bd7c0..bb32c17a2246 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -27,6 +27,26 @@ namespace tvm { +GlobalVar WithFields(GlobalVar global_var, Optional opt_name_hint, Optional opt_type, + Optional opt_virtual_device, Optional opt_span) { + String name_hint = opt_name_hint.value_or(global_var->name_hint); + Type type = opt_type.value_or(global_var->checked_type()); + VirtualDevice virtual_device = opt_virtual_device.value_or(global_var->virtual_device()); + Span span = opt_span.value_or(global_var->span); + bool all_fields_unchanged = + name_hint.same_as(global_var->name_hint) && type.same_as(global_var->checked_type()) && + virtual_device.same_as(global_var->virtual_device()) && span.same_as(global_var->span); + if (!all_fields_unchanged) { + GlobalVarNode* cow_global_var_node = global_var.CopyOnWrite(); + cow_global_var_node->name_hint = name_hint; + cow_global_var_node->checked_type_ = type; + cow_global_var_node->virtual_device_ = virtual_device; + cow_global_var_node->span = span; + } + + return global_var; +} + VirtualDevice RelayExprNode::virtual_device() const { if (!this->virtual_device_.defined()) { // virtual_device_ should always be defined, unless we imported this node from JSON using an old @@ -77,6 +97,25 @@ TensorType ConstantNode::tensor_type() const { return TensorType(shape, dtype); } +Constant WithFields(Constant constant, Optional opt_data, + Optional opt_virtual_device, Optional opt_span) { + runtime::NDArray data = opt_data.value_or(constant->data); + VirtualDevice virtual_device = opt_virtual_device.value_or(constant->virtual_device()); + Span span = opt_span.value_or(constant->span); + + bool all_fields_unchanged = data.same_as(constant->data) && + virtual_device.same_as(constant->virtual_device()) && + span.same_as(constant->span); + + if (!all_fields_unchanged) { + ConstantNode* cow_constant_node = constant.CopyOnWrite(); + cow_constant_node->data = data; + cow_constant_node->virtual_device_ = virtual_device; + cow_constant_node->span = span; + } + return constant; +} + Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); diff --git a/tests/cpp/withfields_test.cc b/tests/cpp/withfields_test.cc new file mode 100644 index 000000000000..73f5d3c920b5 --- /dev/null +++ b/tests/cpp/withfields_test.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace { + +TEST(WithFields, GlobalVar) { + auto tensor_type = relay::TensorType({}, DataType::Bool()); + GlobalVar func_init("dummy_func", tensor_type, {}); + GlobalVar func_cp = WithFields(func_init); + ICHECK(func_init->name_hint == func_cp->name_hint); + ICHECK(func_init->span == func_cp->span); + ICHECK(func_init->checked_type_ == func_cp->checked_type_); +} + +TEST(WithFields, Constant) { + int64_t out_channels = 64; + Device dev{DLDeviceType::kDLCPU, 0}; + runtime::NDArray multiplier_nda = runtime::NDArray::Empty({out_channels}, DataType::Int(32), dev); + Constant constant_init(multiplier_nda, {}); + Constant constant_cp = WithFields(constant_init); + ICHECK(constant_init->checked_type_ == constant_cp->checked_type_); + ICHECK_EQ(constant_init->data, constant_cp->data); + ICHECK_EQ(constant_init->span, constant_cp->span); +} + +} // namespace +} // namespace relay +} // namespace tvm