@@ -797,9 +797,11 @@ class ExprStmtDuplicatorVisitor(ASDLVisitor):
797797 def __init__ (self , stream , data ):
798798 self .duplicate_stmt = []
799799 self .duplicate_expr = []
800+ self .duplicate_ttype = []
800801 self .duplicate_case_stmt = []
801802 self .is_stmt = False
802803 self .is_expr = False
804+ self .is_ttype = False
803805 self .is_case_stmt = False
804806 self .is_product = False
805807 super (ExprStmtDuplicatorVisitor , self ).__init__ (stream , data )
@@ -834,6 +836,13 @@ def visitModule(self, mod):
834836 self .duplicate_expr .append (("" , 0 ))
835837 self .duplicate_expr .append ((" switch(x->type) {" , 1 ))
836838
839+ self .duplicate_ttype .append ((" ASR::ttype_t* duplicate_ttype(ASR::ttype_t* x) {" , 0 ))
840+ self .duplicate_ttype .append ((" if( !x ) {" , 1 ))
841+ self .duplicate_ttype .append ((" return nullptr;" , 2 ))
842+ self .duplicate_ttype .append ((" }" , 1 ))
843+ self .duplicate_ttype .append (("" , 0 ))
844+ self .duplicate_ttype .append ((" switch(x->type) {" , 1 ))
845+
837846 self .duplicate_case_stmt .append ((" ASR::case_stmt_t* duplicate_case_stmt(ASR::case_stmt_t* x) {" , 0 ))
838847 self .duplicate_case_stmt .append ((" if( !x ) {" , 1 ))
839848 self .duplicate_case_stmt .append ((" return nullptr;" , 2 ))
@@ -858,6 +867,14 @@ def visitModule(self, mod):
858867 self .duplicate_expr .append ((" return nullptr;" , 1 ))
859868 self .duplicate_expr .append ((" }" , 0 ))
860869
870+ self .duplicate_ttype .append ((" default: {" , 2 ))
871+ self .duplicate_ttype .append ((' LCOMPILERS_ASSERT_MSG(false, "Duplication of " + std::to_string(x->type) + " type is not supported yet.");' , 3 ))
872+ self .duplicate_ttype .append ((" }" , 2 ))
873+ self .duplicate_ttype .append ((" }" , 1 ))
874+ self .duplicate_ttype .append (("" , 0 ))
875+ self .duplicate_ttype .append ((" return nullptr;" , 1 ))
876+ self .duplicate_ttype .append ((" }" , 0 ))
877+
861878 self .duplicate_case_stmt .append ((" default: {" , 2 ))
862879 self .duplicate_case_stmt .append ((' LCOMPILERS_ASSERT_MSG(false, "Duplication of " + std::to_string(x->type) + " case statement is not supported yet.");' , 3 ))
863880 self .duplicate_case_stmt .append ((" }" , 2 ))
@@ -872,6 +889,9 @@ def visitModule(self, mod):
872889 for line , level in self .duplicate_expr :
873890 self .emit (line , level = level )
874891 self .emit ("" )
892+ for line , level in self .duplicate_ttype :
893+ self .emit (line , level = level )
894+ self .emit ("" )
875895 for line , level in self .duplicate_case_stmt :
876896 self .emit (line , level = level )
877897 self .emit ("" )
@@ -885,8 +905,9 @@ def visitType(self, tp):
885905 def visitSum (self , sum , * args ):
886906 self .is_stmt = args [0 ] == 'stmt'
887907 self .is_expr = args [0 ] == 'expr'
908+ self .is_ttype = args [0 ] == "ttype"
888909 self .is_case_stmt = args [0 ] == 'case_stmt'
889- if self .is_stmt or self .is_expr or self .is_case_stmt :
910+ if self .is_stmt or self .is_expr or self .is_case_stmt or self . is_ttype :
890911 for tp in sum .types :
891912 self .visit (tp , * args )
892913
@@ -933,6 +954,10 @@ def make_visitor(self, name, fields):
933954 self .duplicate_expr .append ((" }" , 3 ))
934955 self .duplicate_expr .append ((" return down_cast<ASR::expr_t>(self().duplicate_%s(down_cast<ASR::%s_t>(x)));" % (name , name ), 3 ))
935956 self .duplicate_expr .append ((" }" , 2 ))
957+ elif self .is_ttype :
958+ self .duplicate_ttype .append ((" case ASR::ttypeType::%s: {" % name , 2 ))
959+ self .duplicate_ttype .append ((" return down_cast<ASR::ttype_t>(self().duplicate_%s(down_cast<ASR::%s_t>(x)));" % (name , name ), 3 ))
960+ self .duplicate_ttype .append ((" }" , 2 ))
936961 elif self .is_case_stmt :
937962 self .duplicate_case_stmt .append ((" case ASR::case_stmtType::%s: {" % name , 2 ))
938963 self .duplicate_case_stmt .append ((" return down_cast<ASR::case_stmt_t>(self().duplicate_%s(down_cast<ASR::%s_t>(x)));" % (name , name ), 3 ))
@@ -949,7 +974,8 @@ def visitField(self, field):
949974 field .type == "do_loop_head" or
950975 field .type == "array_index" or
951976 field .type == "alloc_arg" or
952- field .type == "case_stmt" ):
977+ field .type == "case_stmt" or
978+ field .type == "ttype" ):
953979 level = 2
954980 if field .seq :
955981 self .used = True
@@ -1107,10 +1133,12 @@ def visitField(self, field):
11071133 self .used = True
11081134 self .emit ("for (size_t i = 0; i < x->n_%s; i++) {" % field .name , level )
11091135 if field .type == "call_arg" :
1110- self .emit (" ASR::expr_t** current_expr_copy_%d = current_expr;" % (self .current_expr_copy_variable_count ), level )
1111- self .emit (" current_expr = &(x->m_%s[i].m_value);" % (field .name ), level )
1112- self .emit (" self().replace_expr(x->m_%s[i].m_value);" % (field .name ), level )
1113- self .emit (" current_expr = current_expr_copy_%d;" % (self .current_expr_copy_variable_count ), level )
1136+ self .emit (" if (x->m_%s[i].m_value != nullptr) {" % (field .name ), level )
1137+ self .emit (" ASR::expr_t** current_expr_copy_%d = current_expr;" % (self .current_expr_copy_variable_count ), level + 1 )
1138+ self .emit (" current_expr = &(x->m_%s[i].m_value);" % (field .name ), level + 1 )
1139+ self .emit (" self().replace_expr(x->m_%s[i].m_value);" % (field .name ), level + 1 )
1140+ self .emit (" current_expr = current_expr_copy_%d;" % (self .current_expr_copy_variable_count ), level + 1 )
1141+ self .emit (" }" , level )
11141142 self .current_expr_copy_variable_count += 1
11151143 self .emit ("}" , level )
11161144 else :
@@ -2310,6 +2338,8 @@ def make_visitor(self, name, fields):
23102338 LCOMPILERS_ASSERT(e->m_external);
23112339 LCOMPILERS_ASSERT(!ASR::is_a<ASR::ExternalSymbol_t>(*e->m_external));
23122340 s = e->m_external;
2341+ } else if (s->type == ASR::symbolType::Function) {
2342+ return ASR::down_cast<ASR::Function_t>(s)->m_function_signature;
23132343 }
23142344 return ASR::down_cast<ASR::Variable_t>(s)->m_type;
23152345 }""" \
@@ -2529,6 +2559,9 @@ def main(argv):
25292559 subs ["MOD" ] = "LPython::AST"
25302560 subs ["mod" ] = "ast"
25312561 subs ["lcompiler" ] = "lpython"
2562+ elif subs ["MOD" ] == "AST" :
2563+ subs ["MOD" ] = "LFortran::AST"
2564+ subs ["lcompiler" ] = "lfortran"
25322565 else :
25332566 subs ["lcompiler" ] = "lfortran"
25342567 is_asr = (mod .name .upper () == "ASR" )
0 commit comments