diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 8a41ab74658f..36ee63058c3c 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -36,6 +36,7 @@ #include #include +#include #include #include #include diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h index e858c4458054..b18f762438e0 100644 --- a/include/tvm/runtime/vm/bytecode.h +++ b/include/tvm/runtime/vm/bytecode.h @@ -68,6 +68,9 @@ enum class Opcode { ShapeOf = 17U, ReshapeTensor = 18U, DeviceCopy = 19U, + RefCreate = 20U, + RefRead = 21U, + RefWrite = 22U, }; /*! \brief A single virtual machine instruction. @@ -215,6 +218,16 @@ struct Instruction { /*! \brief The destination device type. */ Index dst_device_type; }; + struct /* RefCreate Operands */ { + RegName initial_value; + } ref_create; + struct /* RefRead */ { + RegName ref; + } ref_read; + struct /* RefWrite */ { + RegName ref; + RegName value; + } ref_write; }; /*! @@ -384,6 +397,31 @@ struct Instruction { static Instruction DeviceCopy(RegName src, Index src_device_type, Index dst_device_type, RegName dst); + /*! + * \brief Create a reference. + * \param value The register corresponding to the initial value of the reference. + * \param dst The destination register. + * \return The reference creation instruction. + */ + static Instruction RefCreate(RegName value, RegName dst); + + /*! + * \brief Read a value from a reference. + * \param ref The register to read from. + * \param dst The destination register. + * \return The reference read instruction. + */ + static Instruction RefRead(RegName ref, RegName dst); + + /*! + * \brief Write a value to the reference. + * \param ref The register to write to. + * \param value The value to write to the register. + * \param dst The destination register. + * \return The reference write instruction. + */ + static Instruction RefWrite(RegName ref, RegName value, RegName dst); + Instruction(); Instruction(const Instruction& instr); Instruction& operator=(const Instruction& instr); diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 65b0c0ba87c7..b83e32f1f305 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -96,4 +96,9 @@ def _tensor_constant_repr(tvalue): return str(tvalue.data.asnumpy()) +@tvm._ffi.register_func("relay._ndarray_repr") +def _ndarray_repr(tvalue): + return str(tvalue.asnumpy()) + + tvm._ffi._init_api("relay.backend", __name__) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index f652644afa3c..450144431952 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -328,6 +328,9 @@ class VMFunctionCompiler : ExprFunctor { case Opcode::Move: case Opcode::InvokeClosure: case Opcode::DeviceCopy: + case Opcode::RefCreate: + case Opcode::RefRead: + case Opcode::RefWrite: last_register_ = instr.dst; break; case Opcode::InvokePacked: @@ -405,6 +408,26 @@ class VMFunctionCompiler : ExprFunctor { Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister())); } + void VisitExpr_(const RefCreateNode* ref_create) { + this->VisitExpr(ref_create->value); + auto value_register = last_register_; + Emit(Instruction::RefCreate(value_register, NewRegister())); + } + + void VisitExpr_(const RefReadNode* ref_read) { + this->VisitExpr(ref_read->ref); + auto ref_register = last_register_; + Emit(Instruction::RefRead(ref_register, NewRegister())); + } + + void VisitExpr_(const RefWriteNode* ref_write) { + this->VisitExpr(ref_write->ref); + auto ref_register = last_register_; + this->VisitExpr(ref_write->value); + auto value_register = last_register_; + Emit(Instruction::RefWrite(ref_register, value_register, NewRegister())); + } + void VisitExpr_(const IfNode* if_node) { this->VisitExpr(if_node->cond); diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index 2e7c08a684dc..bc740b3b23fd 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -82,7 +82,8 @@ class Eliminator : private ExprMutator { Expr VisitExpr_(const LetNode* op) final { Var v = op->var; - if (HasLet(v)) { + // TODO(@jroesch, @altanh, @M.K.): fix DCE with refs (#6803) + if (HasLet(v) || op->value.as()) { return Let(v, VisitExpr(op->value), VisitExpr(op->body)); } else { return VisitExpr(op->body); diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc index f82d708468f7..d82aadc6bdd8 100644 --- a/src/runtime/vm/bytecode.cc +++ b/src/runtime/vm/bytecode.cc @@ -127,6 +127,15 @@ Instruction::Instruction(const Instruction& instr) { this->src_device_type = instr.src_device_type; this->dst_device_type = instr.dst_device_type; return; + case Opcode::RefCreate: + this->ref_create = instr.ref_create; + return; + case Opcode::RefRead: + this->ref_read = instr.ref_read; + return; + case Opcode::RefWrite: + this->ref_write = instr.ref_write; + return; default: std::ostringstream out; out << "Invalid instruction " << static_cast(instr.op); @@ -233,6 +242,15 @@ Instruction& Instruction::operator=(const Instruction& instr) { this->src_device_type = instr.src_device_type; this->dst_device_type = instr.dst_device_type; return *this; + case Opcode::RefCreate: + this->ref_create = instr.ref_create; + return *this; + case Opcode::RefRead: + this->ref_read = instr.ref_read; + return *this; + case Opcode::RefWrite: + this->ref_write = instr.ref_write; + return *this; default: std::ostringstream out; out << "Invalid instruction " << static_cast(instr.op); @@ -255,6 +273,9 @@ Instruction::~Instruction() { case Opcode::ShapeOf: case Opcode::ReshapeTensor: case Opcode::DeviceCopy: + case Opcode::RefCreate: + case Opcode::RefRead: + case Opcode::RefWrite: case Opcode::Fatal: return; case Opcode::AllocTensor: @@ -491,6 +512,31 @@ Instruction Instruction::Move(RegName src, RegName dst) { return instr; } +Instruction Instruction::RefCreate(RegName value, RegName dst) { + Instruction instr; + instr.op = Opcode::RefCreate; + instr.dst = dst; + instr.ref_create.initial_value = value; + return instr; +} + +Instruction Instruction::RefRead(RegName ref, RegName dst) { + Instruction instr; + instr.op = Opcode::RefRead; + instr.dst = dst; + instr.ref_read.ref = ref; + return instr; +} + +Instruction Instruction::RefWrite(RegName ref, RegName value, RegName dst) { + Instruction instr; + instr.op = Opcode::RefWrite; + instr.dst = dst; + instr.ref_write.ref = ref; + instr.ref_write.value = value; + return instr; +} + void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) { switch (dtype.code) { case kDLInt: @@ -626,6 +672,19 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { << instr.src_device_type; break; } + case Opcode::RefCreate: { + os << "ref_create $" << instr.dst << " $" << instr.ref_create.initial_value; + break; + } + case Opcode::RefRead: { + os << "ref_ref $" << instr.dst << " $" << instr.ref_read.ref; + break; + } + case Opcode::RefWrite: { + os << "ref_write $" << instr.dst << " $" << instr.ref_write.ref << " $" + << instr.ref_write.value; + break; + } default: LOG(FATAL) << "should never hit this case" << static_cast(instr.op); break; diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index eb1707b25aa3..fa6371f51655 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -442,6 +442,21 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.src, instr.src_device_type, instr.dst_device_type, instr.dst}); break; } + case Opcode::RefCreate: { + // Number of fields = 2 + fields.assign({instr.ref_create.initial_value, instr.dst}); + break; + } + case Opcode::RefRead: { + // Number of fields = 2 + fields.assign({instr.ref_read.ref, instr.dst}); + break; + } + case Opcode::RefWrite: { + // Number of fields = 3 + fields.assign({instr.ref_write.ref, instr.ref_write.value, instr.dst}); + break; + } default: LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); break; @@ -734,6 +749,21 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { return Instruction::DeviceCopy(instr.fields[0], instr.fields[1], instr.fields[2], instr.fields[3]); } + case Opcode::RefCreate: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::RefCreate(instr.fields[0], instr.fields[1]); + } + case Opcode::RefRead: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::RefRead(instr.fields[0], instr.fields[1]); + } + case Opcode::RefWrite: { + // Number of fields = 3 + DCHECK_EQ(instr.fields.size(), 3U); + return Instruction::RefWrite(instr.fields[0], instr.fields[1], instr.fields[2]); + } default: LOG(FATAL) << "Invalid opcode" << instr.opcode; return Instruction(); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 473b5d759272..b8ae9c92d3b7 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -262,6 +262,12 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); } +std::string DebugPrint(NDArray array) { + const PackedFunc* fprint = Registry::Get("relay._ndarray_repr"); + ICHECK(fprint) << "unable to find printing function for constants"; + return (*fprint)(array); +} + void VirtualMachine::LoadExecutable(const Executable* exec) { ICHECK(exec) << "The executable is not created yet."; exec_ = exec; @@ -595,6 +601,29 @@ void VirtualMachine::RunLoop() { pc_++; goto main_loop; } + case Opcode::RefRead: { + ADT ref = Downcast(ReadRegister(instr.ref_read.ref)); + WriteRegister(instr.dst, ref[0]); + pc_++; + goto main_loop; + } + case Opcode::RefWrite: { + ADT ref = Downcast(ReadRegister(instr.ref_write.ref)); + ObjectRef value = ReadRegister(instr.ref_write.value); + // Not sure about this being best way to implement mutable thing. + ADTObj* mut_array = const_cast(ref.as()); + ObjectRef& inner_value = (*mut_array)[0]; + inner_value = value; + pc_++; + goto main_loop; + } + case Opcode::RefCreate: { + auto value = {ReadRegister(instr.ref_create.initial_value)}; + auto ref = ADT::Tuple(value); + WriteRegister(instr.dst, ref); + pc_++; + goto main_loop; + } default: LOG(FATAL) << "Unknown instruction opcode: " << int(instr.op); } diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 6958010176e3..750026268e72 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -770,6 +770,23 @@ def test_vm_reshape_tuple(x_shape=(1, 4, 2), y_shape=(1, 2, 10)): tvm.testing.assert_allclose(res.asnumpy(), np.reshape(x_data, (1, -1))) +@tvm.testing.uses_gpu +def test_vm_ref_create_read_write(): + scope = ScopeBuilder() + ref_const = relay.const(1.0, dtype="float32") + ref_create = scope.let("ref_create", relay.RefCreate(ref_const)) + read_ref = scope.let("read_ref", relay.RefRead(ref_create)) + new_value = read_ref + relay.const(2.0, dtype="float32") + scope.let("", relay.RefWrite(ref_create, new_value)) + scope.ret(relay.RefRead(ref_create)) + + f = relay.Function([], scope.get()) + + for tgt, ctx in tvm.testing.enabled_targets(): + res = veval(f, ctx=ctx, target=tgt) + tvm.testing.assert_allclose(res.asnumpy(), np.array(3)) + + def test_constant_shape_with_external_codegen(): mod = tvm.IRModule() shape = (relay.Any(), 25)