From d1d93ef89ed83c1e9def6171471da8455c69a779 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 7 Jan 2022 13:49:46 -0800 Subject: [PATCH] dynamic to static use infer_type_local --- src/relay/transforms/dynamic_to_static.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index f3c53cfc8bc0..bafdbd359141 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -240,6 +240,14 @@ class DynamicToStaticMutator : public MixedModeMutator { gv_ = vars[func_]; } + Expr GetCurExpr(const Expr& original_expr) { + if (original_expr.as()) { + return mod_->Lookup(gv_); + } else { + return mod_->Lookup(gv_).as()->body; + } + } + Expr PrepareInput(const Expr& expr) { BaseFunc func; if (auto* func_node = expr.as()) { @@ -249,10 +257,12 @@ class DynamicToStaticMutator : public MixedModeMutator { relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_), {}); } mod_->Update(gv_, func); + mod_ = transform::FoldConstant()(mod_); - mod_ = transform::InferType()(mod_); + transform::InferTypeLocal(GetCurExpr(expr)); mod_ = transform::FoldConstant()(mod_); - mod_ = transform::InferType()(mod_); + transform::InferTypeLocal(GetCurExpr(expr)); + Expr out; if (expr.as()) { out = mod_->Lookup(gv_);