Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
#include <tvm/node/cast.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/repr.h>
#include <tvm/node/script_printer.h>
#include <tvm/runtime/object.h>

Expand Down
28 changes: 11 additions & 17 deletions include/tvm/node/functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,38 +160,32 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* \brief Useful macro to set NodeFunctor dispatch in a global static field.
*
* \code
* // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.
* // Use NodeFunctor to implement TVMScriptPrinter similar to Visitor Pattern.
* // vtable allows easy patch of new Node types, without changing
* // interface of ReprPrinter.
* // the interface of TVMScriptPrinter.
*
* class ReprPrinter {
* class TVMScriptPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
* void print(Expr e) {
* const static FType& f = *vtable();
* f(e, this);
* static std::string Script(const ObjectRef& node, const PrinterConfig& cfg) {
* return vtable()(node, cfg);
* }
*
* using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
* using FType = NodeFunctor<std::string(const ObjectRef&, const PrinterConfig&)>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*)
* TVMScriptPrinter::FType& TVMScriptPrinter::vtable() {
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
* auto* n = static_cast<const Add*>(ref.get());
* p->print(n->a);
* p->stream << '+'
* p->print(n->b);
* TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter, vtable)
* .set_dispatch<AddNode>([](const ObjectRef& ref, const PrinterConfig& cfg) {
* auto* n = static_cast<const AddNode*>(ref.get());
* return Script(n->a, cfg) + " + " + Script(n->b, cfg);
* });
*
*
* \endcode
*
* \param ClsName The name of the class
Expand Down
118 changes: 118 additions & 0 deletions include/tvm/node/repr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/repr.h
* \brief ostream operator<< for ObjectRef, Any, and Variant, delegating to
* ffi::ReprPrint. Also re-exports the Dump() debug helpers.
*
* Include this header wherever you need `os << some_objectref` and you are
* no longer pulling in the legacy repr_printer.h.
*/
#ifndef TVM_NODE_REPR_H_
#define TVM_NODE_REPR_H_

#include <tvm/ffi/extra/dataclass.h>
#include <tvm/ffi/reflection/access_path.h>
#include <tvm/runtime/object.h>

#include <iostream>

namespace tvm {

/*!
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
*/
TVM_DLL void Dump(const runtime::ObjectRef& node);

/*!
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
*/
TVM_DLL void Dump(const runtime::Object* node);

} // namespace tvm

namespace tvm {
namespace ffi {

// ostream << ObjectRef — delegates to ffi::ReprPrint
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
return os << ffi::ReprPrint(Any(n));
}

// ostream << Any — delegates to ffi::ReprPrint
inline std::ostream& operator<<(std::ostream& os, const Any& n) { // NOLINT(*)
return os << ffi::ReprPrint(n);
}

// ostream << Variant<...> — delegates to ffi::ReprPrint
template <typename... V>
inline std::ostream& operator<<(std::ostream& os, const ffi::Variant<V...>& n) { // NOLINT(*)
return os << ffi::ReprPrint(Any(n));
}

namespace reflection {

inline std::ostream& operator<<(std::ostream& os, const AccessStep& step) {
namespace refl = ffi::reflection;
switch (step->kind) {
case refl::AccessKind::kAttr: {
os << '.' << step->key.cast<ffi::String>();
return os;
}
case refl::AccessKind::kArrayItem: {
os << "[" << step->key.cast<int64_t>() << "]";
return os;
}
case refl::AccessKind::kMapItem: {
os << "[" << step->key << "]";
return os;
}
case refl::AccessKind::kAttrMissing: {
os << ".<missing attr " << step->key.cast<ffi::String>() << "`>";
return os;
}
case refl::AccessKind::kArrayItemMissing: {
os << "[<missing item at " << step->key.cast<int64_t>() << ">]";
return os;
}
case refl::AccessKind::kMapItemMissing: {
os << "[<missing item at " << step->key << ">]";
return os;
}
default: {
TVM_FFI_THROW(InternalError) << "Unknown access step kind: " << static_cast<int>(step->kind);
}
}
return os;
}

inline std::ostream& operator<<(std::ostream& os, const AccessPath& path) {
ffi::Array<AccessStep> steps = path->ToSteps();
os << "<root>";
for (const auto& step : steps) {
os << step;
}
return os;
}
} // namespace reflection
} // namespace ffi
} // namespace tvm
#endif // TVM_NODE_REPR_H_
116 changes: 4 additions & 112 deletions include/tvm/node/repr_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,121 +18,13 @@
*/
/*!
* \file tvm/node/repr_printer.h
* \brief Printer class to print repr string of each AST/IR nodes.
* \brief DEPRECATED: The legacy ReprPrinter has been replaced by
* ffi::ReprPrint. This header is kept as an empty shim;
* include <tvm/node/repr.h> instead.
*/
#ifndef TVM_NODE_REPR_PRINTER_H_
#define TVM_NODE_REPR_PRINTER_H_

#include <tvm/ffi/reflection/access_path.h>
#include <tvm/node/functor.h>
#include <tvm/node/script_printer.h>
#include <tvm/node/repr.h>

#include <iostream>
#include <string>

namespace tvm {
/*! \brief A printer class to print the AST/IR nodes. */
class ReprPrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};

explicit ReprPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}

/*! \brief The node to be printed. */
TVM_DLL void Print(const ObjectRef& node);
/*! \brief The node to be printed. */
TVM_DLL void Print(const ffi::Any& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, ReprPrinter*)>;
TVM_DLL static FType& vtable();
};

/*!
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
*/
TVM_DLL void Dump(const runtime::ObjectRef& node);

/*!
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
*/
TVM_DLL void Dump(const runtime::Object* node);

} // namespace tvm

namespace tvm {
namespace ffi {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
ReprPrinter(os).Print(n);
return os;
}

// default print function for any
inline std::ostream& operator<<(std::ostream& os, const Any& n) { // NOLINT(*)
ReprPrinter(os).Print(n);
return os;
}

template <typename... V>
inline std::ostream& operator<<(std::ostream& os, const ffi::Variant<V...>& n) { // NOLINT(*)
ReprPrinter(os).Print(Any(n));
return os;
}

namespace reflection {

inline std::ostream& operator<<(std::ostream& os, const AccessStep& step) {
namespace refl = ffi::reflection;
switch (step->kind) {
case refl::AccessKind::kAttr: {
os << '.' << step->key.cast<ffi::String>();
return os;
}
case refl::AccessKind::kArrayItem: {
os << "[" << step->key.cast<int64_t>() << "]";
return os;
}
case refl::AccessKind::kMapItem: {
os << "[" << step->key << "]";
return os;
}
case refl::AccessKind::kAttrMissing: {
os << ".<missing attr " << step->key.cast<ffi::String>() << "`>";
return os;
}
case refl::AccessKind::kArrayItemMissing: {
os << "[<missing item at " << step->key.cast<int64_t>() << ">]";
return os;
}
case refl::AccessKind::kMapItemMissing: {
os << "[<missing item at " << step->key << ">]";
return os;
}
default: {
TVM_FFI_THROW(InternalError) << "Unknown access step kind: " << static_cast<int>(step->kind);
}
}
return os;
}

inline std::ostream& operator<<(std::ostream& os, const AccessPath& path) {
ffi::Array<AccessStep> steps = path->ToSteps();
os << "<root>";
for (const auto& step : steps) {
os << step;
}
return os;
}
} // namespace reflection
} // namespace ffi
} // namespace tvm
#endif // TVM_NODE_REPR_PRINTER_H_
2 changes: 1 addition & 1 deletion include/tvm/node/script_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class TVM_DLL PrinterConfig : public ObjectRef {
PrinterConfigNode);
};

/*! \brief Legacy behavior of ReprPrinter. */
/*! \brief TVMScript-based printer for IR nodes. */
class TVMScriptPrinter {
public:
/* Convert the object to TVMScript format */
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/exec_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/repr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/vm/bytecode.h>
#include <tvm/runtime/vm/executable.h>
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/s_tir/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ class PyMutatorNode : public MutatorNode {
FAsString f_as_string;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PyMutatorNode>();
// `f_initialize_with_tune_context` is not registered
// `f_apply` is not registered
// `f_clone` is not registered
Expand Down
35 changes: 1 addition & 34 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,40 +535,7 @@ void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) {
this->AddToSelf(other->base * scale);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SplitExprNode*>(node.get());
auto factor_str = [](int64_t f) {
return f == SplitExprNode::kPosInf ? std::string("+inf") : std::to_string(f);
};
p->stream << "split(";
p->Print(op->index);
p->stream << ", lower=" << factor_str(op->lower_factor)
<< ", upper=" << factor_str(op->upper_factor) << ", scale=" << op->scale
<< ", div_mode=";
switch (op->div_mode) {
// No "default", so that the compiler will emit a warning if more div modes are
// added that are not covered by the switch.
case kTruncDiv:
p->stream << "truncdiv";
break;
case kFloorDiv:
p->stream << "floordiv";
break;
}
p->stream << ')';
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SumExprNode*>(node.get());
p->stream << "sum(base=" << op->base;
for (const SplitExpr& s : op->args) {
p->stream << ", ";
p->Print(s);
}
p->stream << ')';
});
// Pattern A (RM): auto-default repr from reflection for SplitExprNode and SumExprNode.

// Sub-class RewriteSimplifier::Impl to take benefit of
// rewriter for condition simplification etc.
Expand Down
10 changes: 1 addition & 9 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,7 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) {
}
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ConstIntBoundNode*>(node.get());
p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value);
p->stream << ',';
PrintBoundValue(p->stream, op->max_value);
p->stream << ']';
});
// Pattern A (RM): auto-default repr from reflection.

// internal entry for const int bound
struct ConstIntBoundAnalyzer::Entry {
Expand Down
Loading
Loading