Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 26 additions & 27 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,7 @@ class BlockReads(SpecialStmt):

def __init__(self):
def reads(
read_regions: Union[BufferSlice, List[BufferSlice]],
*other_regions: BufferSlice,
*read_regions: Union[BufferSlice, List[BufferSlice]],
span: Span = None,
):
assert self.context, "call 'exit_scope' before 'enter_scope'"
Expand All @@ -335,16 +334,18 @@ def reads(
+ str(", ".join(str(x) for x in block_scope.reads)),
span,
)
if isinstance(read_regions, BufferSlice):
read_regions = [read_regions]
for region in other_regions:
read_regions.append(region)
if not isinstance(read_regions, list):
self.context.report_error(
"Incorrect input type. "
+ f"Expected BufferSlice or List[BufferSlice], but got {type(read_regions)}",
span,
)
if len(read_regions) > 1:
for read_region in read_regions:
if not isinstance(read_region, BufferSlice):
self.context.report_error(
"Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
+ f" but got {type(read_regions)}",
span,
)
elif len(read_regions) == 1:
if isinstance(read_regions[0], list):
read_regions = read_regions[0]

block_scope.reads = read_regions

super().__init__(reads, def_symbol=False)
Expand All @@ -368,8 +369,7 @@ class BlockWrites(SpecialStmt):

def __init__(self):
def writes(
write_region: Union[BufferSlice, List[BufferSlice]],
*other_region: BufferSlice,
*write_regions: Union[BufferSlice, List[BufferSlice]],
span: Span = None,
):
assert self.context, "call 'exit_scope' before 'enter_scope'"
Expand All @@ -386,19 +386,18 @@ def writes(
+ str(", ".join(str(x) for x in block_scope.writes)),
span,
)
if isinstance(write_region, list):
pass
elif isinstance(write_region, BufferSlice):
write_region = [write_region]
for region in other_region:
write_region.append(region)
else:
self.context.report_error(
"Incorrect input type. "
+ f"Expected BufferSlice or List[BufferSlice], but got {type(write_region)}",
span,
)
block_scope.writes = write_region
if len(write_regions) > 1:
for write_region in write_regions:
if not isinstance(write_region, BufferSlice):
self.context.report_error(
"Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
+ f" but got {type(write_regions)}",
span,
)
elif len(write_regions) == 1:
if isinstance(write_regions[0], list):
write_regions = write_regions[0]
block_scope.writes = write_regions

super().__init__(writes, def_symbol=False)

Expand Down
112 changes: 104 additions & 8 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintBlockVarRemaps();
Doc PrintBlockVars(const BlockRealizeNode* op);
Doc PrintBlockAttr(const BlockRealizeNode* op);
Doc PrintExpandedArray(const ArrayNode* op);
Doc PrintBlockBody(const BlockNode* op);
virtual Doc PrintBlockName(const BlockNode* block_op);
Doc PrintBufferRegion(const BufferRegionNode* op);
Expand All @@ -220,6 +221,13 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc AllocBuf(const Buffer& buffer);
void TryDeallocVar(const Var& var);
bool ContainsOptionalInfo(const Stmt& stmt);
/*!
* \brief check if a buffer declaration has only 'shape' and 'dtype' arguments specified
* \param buffer The match buffer to be checked
*/
bool IsSimpleBuffer(const Buffer& buffer);
Doc PrintInlineBufferBind(const Buffer& buffer);
Doc PrintTuple(const ArrayNode* op);

/*! Helper functions for loop printing. */
/*!
Expand Down Expand Up @@ -404,7 +412,7 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
if (buf->offset_factor != 1 || print_factor_explicitly) {
doc << ", offset_factor=" << buf->offset_factor;
}
if (buf->buffer_type != 1) {
if (buf->buffer_type != BufferType::kDefault) {
doc << ", type=" << Doc::StrLiteral("auto");
}
return doc;
Expand Down Expand Up @@ -471,6 +479,60 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
return doc;
}

// check if all arguments, except the first two, are specified for T.match_buffer
// if not, then this match buffer is printed out as T.buffer in prim_func arguments
bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) {
if (memo_var_.find(buf->data) != memo_var_.end()) {
return false;
}
if (!buf->strides.empty()) {
return false;
}
if (buf->elem_offset->IsInstance<VarNode>()) {
return false;
} else if (buf->elem_offset->IsInstance<IntImmNode>()) {
IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
if (elem_offset->value != 0) {
return false;
}
}
if (buf.scope() != "global") {
return false;
}
if (buf->data_alignment != runtime::kAllocAlignment) {
return false;
}
if (buf->offset_factor != 1) {
return false;
}
if (buf->buffer_type != BufferType::kDefault) {
return false;
}
return true;
}

Doc TVMScriptPrinter::PrintInlineBufferBind(const Buffer& buffer) {
Doc doc;
doc << tir_prefix_ << ".Buffer[" << PrintTuple(buffer->shape.as<ArrayNode>());
doc << ", " << PrintDType(buffer->dtype) << "]";
return doc;
}

// print array out as tuple with parentheses
Doc TVMScriptPrinter::PrintTuple(const ArrayNode* op) {
Doc doc;
doc << '(';
for (size_t i = 0; i < op->size(); ++i) {
if (i != 0) {
doc << ", ";
}
doc << Print(op->at(i));
}
if (op->size() == 1) doc << ",";
doc << ')';
Comment thread
shingjan marked this conversation as resolved.
return doc;
}

Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) {
Doc doc;
int n_var = static_cast<int>(op->rhs.size());
Expand Down Expand Up @@ -1095,8 +1157,10 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
if (!is_one(op->predicate)) {
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")";
}
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads(" << Print(block_op->reads) << ")";
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes(" << Print(block_op->writes) << ")";
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads("
<< PrintExpandedArray(block_op->reads.as<ArrayNode>()) << ")";
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes("
<< PrintExpandedArray(block_op->writes.as<ArrayNode>()) << ")";
if (!block_op->annotations.empty()) {
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".block_attr({";
block_attr_doc << PrintAnnotations(block_op->annotations);
Expand All @@ -1105,6 +1169,19 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
return block_attr_doc;
}

// This function is to make sure arguments of T.reads() and T.writes() is not parsed by printer as a
// List. Therefore the brackets are removed before and after printing arguments out
Doc TVMScriptPrinter::PrintExpandedArray(const ArrayNode* op) {
Doc doc;
for (size_t i = 0; i < op->size(); ++i) {
if (i != 0) {
doc << ", ";
}
doc << Print(op->at(i));
}
return doc;
}

Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) {
Doc body;
for (const auto& alloc_buf : op->alloc_buffers) {
Expand Down Expand Up @@ -1218,8 +1295,21 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
doc << "def " << (func2var_.find(op) == func2var_.end() ? "func" : func2var_[op]->name_hint)
<< "(";
std::vector<Doc> params;
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> simple_buf;
for (const auto& param : op->params) {
var_not_in_headers_.insert(param.get());
auto it = op->buffer_map.find(param);
// check if this param is a T.handle
if (it != op->buffer_map.end()) {
// check if this match_buffer has only the first two arguments specified
const Buffer& buf = (*it).second;
if (IsSimpleBuffer(buf)) {
simple_buf.insert(buf);
buf_not_in_headers_.insert(buf.get());
params.push_back(Print(buf) << ": " << PrintInlineBufferBind(buf));
continue;
}
}
params.push_back(Print(param) << ": " << Print(GetType(param)));
}
doc << PrintSep(params, Doc::Text(", ")) << ") -> " << Print(primFunc->ret_type) << ":";
Expand All @@ -1229,9 +1319,11 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
for (const auto& param : op->params) {
auto it = op->buffer_map.find(param);
if (it == op->buffer_map.end()) continue;
buf_not_in_headers_.insert((*it).second.get());
body << Print((*it).second) << " = " << tir_prefix_ << ".match_buffer(";
body << Print((*it).first) << ", " << memo_buf_decl_[(*it).second];
const Buffer& buf = (*it).second;
if (simple_buf.count(buf)) continue;
buf_not_in_headers_.insert(buf.get());
body << Print(buf) << " = " << tir_prefix_ << ".match_buffer(";
body << Print((*it).first) << ", " << memo_buf_decl_[buf];
body << ")" << Doc::NewLine();
}
// print body
Expand Down Expand Up @@ -1392,8 +1484,12 @@ Doc TVMScriptPrinter::PrintAnnotations(const Map<String, ObjectRef>& annotations
Doc TVMScriptPrinter::PrintLoop(const For& loop) {
Doc res;
res << "for " << Print(loop->loop_var) << " in " << tir_prefix_
<< "." + std::string(ForKind2String(loop->kind)) + "(" << Print(loop->min) << ", "
<< Print(loop->min + loop->extent);
<< "." + std::string(ForKind2String(loop->kind)) + "(";
if (is_zero(loop->min)) {
res << Print(loop->extent);
} else {
res << Print(loop->min) << ", " << Print(loop->min + loop->extent);
}
if (loop->thread_binding.defined()) {
res << ", thread=";
res << Print(loop->thread_binding.value()->thread_tag);
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_tvmscript_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,10 +544,10 @@ def test_reorder_fail_nested_loop_inner():
with pytest.raises(tvm.tir.ScheduleError) as execinfo:
sch.reorder(k, i)
expected_sub_error_message = (
" for i in T.serial(0, 128):\n"
" for i in T.serial(128):\n"
" # tir.For#0\n"
" for j in T.serial(0, 128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
" for j in T.serial(128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^\n"
)
assert expected_sub_error_message in str(execinfo.value)

Expand All @@ -560,9 +560,9 @@ def test_fuse_fail_nested_loop_outer():
sch.fuse(k, i)
expected_sub_error_message = (
" # tir.For#1\n"
" for i in T.serial(0, 128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
" for j in T.serial(0, 128):\n"
" for i in T.serial(128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^\n"
" for j in T.serial(128):\n"
)
assert expected_sub_error_message in str(execinfo.value)

Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def elementwise_handle(
# match buffer - use buffer with kwargs
@T.prim_func
def elementwise_buffer_kwargs(
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32"),
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32"),
) -> None:
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
Expand Down