diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index ee2869d99303..c4851e255f0e 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -138,6 +138,11 @@ class FlopEstimator : private ExprFunctor, } TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); } + TResult VisitStmt_(const AttrStmtNode* op) override { + TResult result = VisitStmt(op->body); + result += VisitExpr(op->value); + return result; + } TResult VisitStmt_(const BufferStoreNode* store) override { return VisitExpr(store->value); } TResult VisitStmt_(const BlockRealizeNode* block) override { return VisitStmt(block->block->body); @@ -186,6 +191,7 @@ class FlopEstimator : private ExprFunctor, TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); } TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); } TResult VisitStmt_(const AllocateConstNode* op) override { return VisitStmt(op->body); } + TResult VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); } TResult VisitStmt_(const DeclBufferNode* op) override { return VisitStmt(op->body); } TResult VisitStmt_(const SeqStmtNode* seq) override {