Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
#include <tvm/target/target.h>
Expand Down
38 changes: 38 additions & 0 deletions include/tvm/runtime/vm/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ enum class Opcode {
ShapeOf = 17U,
ReshapeTensor = 18U,
DeviceCopy = 19U,
RefCreate = 20U,
RefRead = 21U,
RefWrite = 22U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -215,6 +218,16 @@ struct Instruction {
/*! \brief The destination device type. */
Index dst_device_type;
};
struct /* RefCreate Operands */ {
RegName initial_value;
} ref_create;
struct /* RefRead */ {
RegName ref;
} ref_read;
struct /* RefWrite */ {
RegName ref;
RegName value;
} ref_write;
};

/*!
Expand Down Expand Up @@ -384,6 +397,31 @@ struct Instruction {
static Instruction DeviceCopy(RegName src, Index src_device_type, Index dst_device_type,
RegName dst);

/*!
* \brief Create a reference.
* \param value The register corresponding to the initial value of the reference.
* \param dst The destination register.
* \return The reference creation instruction.
*/
static Instruction RefCreate(RegName value, RegName dst);

/*!
* \brief Read a value from a reference.
* \param ref The register to read from.
* \param dst The destination register.
* \return The reference read instruction.
*/
static Instruction RefRead(RegName ref, RegName dst);

/*!
* \brief Write a value to the reference.
* \param ref The register to write to.
* \param value The value to write to the register.
* \param dst The destination register.
* \return The reference write instruction.
*/
static Instruction RefWrite(RegName ref, RegName value, RegName dst);

Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr);
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/backend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,9 @@ def _tensor_constant_repr(tvalue):
return str(tvalue.data.asnumpy())


@tvm._ffi.register_func("relay._ndarray_repr")
def _ndarray_repr(tvalue):
return str(tvalue.asnumpy())


tvm._ffi._init_api("relay.backend", __name__)
23 changes: 23 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,9 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::Move:
case Opcode::InvokeClosure:
case Opcode::DeviceCopy:
case Opcode::RefCreate:
case Opcode::RefRead:
case Opcode::RefWrite:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RefWrite doesn't have last register, right? Maybe set the last_register_ to -1?

last_register_ = instr.dst;
break;
case Opcode::InvokePacked:
Expand Down Expand Up @@ -405,6 +408,26 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister()));
}

void VisitExpr_(const RefCreateNode* ref_create) {
this->VisitExpr(ref_create->value);
auto value_register = last_register_;
Emit(Instruction::RefCreate(value_register, NewRegister()));
}

void VisitExpr_(const RefReadNode* ref_read) {
this->VisitExpr(ref_read->ref);
auto ref_register = last_register_;
Emit(Instruction::RefRead(ref_register, NewRegister()));
}

void VisitExpr_(const RefWriteNode* ref_write) {
this->VisitExpr(ref_write->ref);
auto ref_register = last_register_;
this->VisitExpr(ref_write->value);
auto value_register = last_register_;
Emit(Instruction::RefWrite(ref_register, value_register, NewRegister()));
}

void VisitExpr_(const IfNode* if_node) {
this->VisitExpr(if_node->cond);

Expand Down
3 changes: 2 additions & 1 deletion src/relay/transforms/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ class Eliminator : private ExprMutator {

Expr VisitExpr_(const LetNode* op) final {
Var v = op->var;
if (HasLet(v)) {
// TODO(@jroesch, @altanh, @M.K.): fix DCE with refs (#6803)
if (HasLet(v) || op->value.as<RefWriteNode>()) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe this fully fixes DCE due to nesting ref_write deep inside lets that get DCEd

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should stop it from removing RefWrites anywhere and by extension keep all ref operations alive as the reference is now live too.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will give you a test case for this since I made this exact change when debugging and it wasn't enough to fix the problem. (In fact I remember segfaulting...)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, out of the scope of this PR, do we need to do some alias analysis to handle reference for DCE?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe so, for now we should probably require DCE to error when the incoming module contains references. @jroesch and I are planning on writing an alias analysis pass soon.

return Let(v, VisitExpr(op->value), VisitExpr(op->body));
} else {
return VisitExpr(op->body);
Expand Down
59 changes: 59 additions & 0 deletions src/runtime/vm/bytecode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ Instruction::Instruction(const Instruction& instr) {
this->src_device_type = instr.src_device_type;
this->dst_device_type = instr.dst_device_type;
return;
case Opcode::RefCreate:
this->ref_create = instr.ref_create;
return;
case Opcode::RefRead:
this->ref_read = instr.ref_read;
return;
case Opcode::RefWrite:
this->ref_write = instr.ref_write;
return;
default:
std::ostringstream out;
out << "Invalid instruction " << static_cast<int>(instr.op);
Expand Down Expand Up @@ -233,6 +242,15 @@ Instruction& Instruction::operator=(const Instruction& instr) {
this->src_device_type = instr.src_device_type;
this->dst_device_type = instr.dst_device_type;
return *this;
case Opcode::RefCreate:
this->ref_create = instr.ref_create;
return *this;
case Opcode::RefRead:
this->ref_read = instr.ref_read;
return *this;
case Opcode::RefWrite:
this->ref_write = instr.ref_write;
return *this;
default:
std::ostringstream out;
out << "Invalid instruction " << static_cast<int>(instr.op);
Expand All @@ -255,6 +273,9 @@ Instruction::~Instruction() {
case Opcode::ShapeOf:
case Opcode::ReshapeTensor:
case Opcode::DeviceCopy:
case Opcode::RefCreate:
case Opcode::RefRead:
case Opcode::RefWrite:
case Opcode::Fatal:
return;
case Opcode::AllocTensor:
Expand Down Expand Up @@ -491,6 +512,31 @@ Instruction Instruction::Move(RegName src, RegName dst) {
return instr;
}

Instruction Instruction::RefCreate(RegName value, RegName dst) {
Instruction instr;
instr.op = Opcode::RefCreate;
instr.dst = dst;
instr.ref_create.initial_value = value;
return instr;
}

Instruction Instruction::RefRead(RegName ref, RegName dst) {
Instruction instr;
instr.op = Opcode::RefRead;
instr.dst = dst;
instr.ref_read.ref = ref;
return instr;
}

Instruction Instruction::RefWrite(RegName ref, RegName value, RegName dst) {
Instruction instr;
instr.op = Opcode::RefWrite;
instr.dst = dst;
instr.ref_write.ref = ref;
instr.ref_write.value = value;
return instr;
}

void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) {
switch (dtype.code) {
case kDLInt:
Expand Down Expand Up @@ -626,6 +672,19 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
<< instr.src_device_type;
break;
}
case Opcode::RefCreate: {
os << "ref_create $" << instr.dst << " $" << instr.ref_create.initial_value;
break;
}
case Opcode::RefRead: {
os << "ref_ref $" << instr.dst << " $" << instr.ref_read.ref;
break;
}
case Opcode::RefWrite: {
os << "ref_write $" << instr.dst << " $" << instr.ref_write.ref << " $"
<< instr.ref_write.value;
break;
}
default:
LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op);
break;
Expand Down
30 changes: 30 additions & 0 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,21 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.assign({instr.src, instr.src_device_type, instr.dst_device_type, instr.dst});
break;
}
case Opcode::RefCreate: {
// Number of fields = 2
fields.assign({instr.ref_create.initial_value, instr.dst});
break;
}
case Opcode::RefRead: {
// Number of fields = 2
fields.assign({instr.ref_read.ref, instr.dst});
break;
}
case Opcode::RefWrite: {
// Number of fields = 3
fields.assign({instr.ref_write.ref, instr.ref_write.value, instr.dst});
break;
}
default:
LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
break;
Expand Down Expand Up @@ -734,6 +749,21 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
return Instruction::DeviceCopy(instr.fields[0], instr.fields[1], instr.fields[2],
instr.fields[3]);
}
case Opcode::RefCreate: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::RefCreate(instr.fields[0], instr.fields[1]);
}
case Opcode::RefRead: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::RefRead(instr.fields[0], instr.fields[1]);
}
case Opcode::RefWrite: {
// Number of fields = 3
DCHECK_EQ(instr.fields.size(), 3U);
return Instruction::RefWrite(instr.fields[0], instr.fields[1], instr.fields[2]);
}
default:
LOG(FATAL) << "Invalid opcode" << instr.opcode;
return Instruction();
Expand Down
29 changes: 29 additions & 0 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,12 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In
func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
}

std::string DebugPrint(NDArray array) {
const PackedFunc* fprint = Registry::Get("relay._ndarray_repr");
ICHECK(fprint) << "unable to find printing function for constants";
return (*fprint)(array);
}

Comment on lines +265 to +270

@mbrookhart mbrookhart Oct 30, 2020

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make this part of the ReprPrinter in ndarray.cc? I've been using something very similar in the VM, but I feel like we should make it more general.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree on this

void VirtualMachine::LoadExecutable(const Executable* exec) {
ICHECK(exec) << "The executable is not created yet.";
exec_ = exec;
Expand Down Expand Up @@ -595,6 +601,29 @@ void VirtualMachine::RunLoop() {
pc_++;
goto main_loop;
}
case Opcode::RefRead: {
ADT ref = Downcast<ADT>(ReadRegister(instr.ref_read.ref));
WriteRegister(instr.dst, ref[0]);
pc_++;
goto main_loop;
}
case Opcode::RefWrite: {
ADT ref = Downcast<ADT>(ReadRegister(instr.ref_write.ref));
ObjectRef value = ReadRegister(instr.ref_write.value);
// Not sure about this being best way to implement mutable thing.
ADTObj* mut_array = const_cast<ADTObj*>(ref.as<ADTObj>());
ObjectRef& inner_value = (*mut_array)[0];
inner_value = value;
pc_++;
goto main_loop;
}
case Opcode::RefCreate: {
auto value = {ReadRegister(instr.ref_create.initial_value)};
auto ref = ADT::Tuple(value);
WriteRegister(instr.dst, ref);
pc_++;
goto main_loop;
}
default:
LOG(FATAL) << "Unknown instruction opcode: " << int(instr.op);
}
Expand Down
17 changes: 17 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,23 @@ def test_vm_reshape_tuple(x_shape=(1, 4, 2), y_shape=(1, 2, 10)):
tvm.testing.assert_allclose(res.asnumpy(), np.reshape(x_data, (1, -1)))


@tvm.testing.uses_gpu
def test_vm_ref_create_read_write():
scope = ScopeBuilder()
ref_const = relay.const(1.0, dtype="float32")
ref_create = scope.let("ref_create", relay.RefCreate(ref_const))
read_ref = scope.let("read_ref", relay.RefRead(ref_create))
new_value = read_ref + relay.const(2.0, dtype="float32")
scope.let("", relay.RefWrite(ref_create, new_value))
scope.ret(relay.RefRead(ref_create))

f = relay.Function([], scope.get())

for tgt, ctx in tvm.testing.enabled_targets():
res = veval(f, ctx=ctx, target=tgt)
tvm.testing.assert_allclose(res.asnumpy(), np.array(3))


def test_constant_shape_with_external_codegen():
mod = tvm.IRModule()
shape = (relay.Any(), 25)
Expand Down