@@ -139,6 +139,7 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab,
139139 return mod2;
140140}
141141
142+
142143template <class Derived >
143144class CommonVisitor : public AST ::BaseVisitor<Derived> {
144145public:
@@ -156,10 +157,13 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
156157 // The main module is stored directly in TranslationUnit, other modules are Modules
157158 bool main_module;
158159 PythonIntrinsicProcedures intrinsic_procedures;
160+ std::map<int , ASR::symbol_t *> &ast_overload;
159161
160162 CommonVisitor (Allocator &al, SymbolTable *symbol_table,
161- diag::Diagnostics &diagnostics, bool main_module)
162- : diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module} {
163+ diag::Diagnostics &diagnostics, bool main_module,
164+ std::map<int , ASR::symbol_t *> &ast_overload)
165+ : diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module},
166+ ast_overload{ast_overload} {
163167 current_module_dependencies.reserve (al, 4 );
164168 }
165169
@@ -445,7 +449,7 @@ ASR::symbol_t* import_from_module(Allocator &al, ASR::Module_t *m, SymbolTable *
445449 throw SemanticError (" Only Subroutines, Functions and Variables are currently supported in 'import'" ,
446450 loc);
447451 }
448- // should not reach here
452+ LFORTRAN_ASSERT ( false );
449453 return nullptr ;
450454}
451455
@@ -469,11 +473,13 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
469473 std::map<SymbolTable*, ASR::accessType> assgn;
470474 ASR::symbol_t *current_module_sym;
471475 std::vector<std::string> excluded_from_symtab;
476+ std::map<std::string, Vec<ASR::symbol_t * >> overload_defs;
472477
473478
474479 SymbolTableVisitor (Allocator &al, SymbolTable *symbol_table,
475- diag::Diagnostics &diagnostics, bool main_module)
476- : CommonVisitor(al, symbol_table, diagnostics, main_module), is_derived_type{false } {}
480+ diag::Diagnostics &diagnostics, bool main_module,
481+ std::map<int , ASR::symbol_t *> &ast_overload)
482+ : CommonVisitor(al, symbol_table, diagnostics, main_module, ast_overload), is_derived_type{false } {}
477483
478484
479485 ASR::symbol_t * resolve_symbol (const Location &loc, const std::string &sub_name) {
@@ -522,7 +528,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
522528 for (size_t i=0 ; i<x.n_body ; i++) {
523529 visit_stmt (*x.m_body [i]);
524530 }
525-
531+ if (!overload_defs.empty ()) {
532+ create_GenericProcedure (x.base .base .loc );
533+ }
526534 global_scope = nullptr ;
527535 tmp = tmp0;
528536 }
@@ -534,12 +542,23 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
534542 Vec<ASR::expr_t *> args;
535543 args.reserve (al, x.m_args .n_args );
536544 current_procedure_abi_type = ASR::abiType::Source;
537- if (x.n_decorator_list == 1 ) {
538- AST::expr_t *dec = x.m_decorator_list [0 ];
539- if (AST::is_a<AST::Name_t>(*dec)) {
540- std::string name = AST::down_cast<AST::Name_t>(dec)->m_id ;
541- if (name == " ccall" ) {
542- current_procedure_abi_type = ASR::abiType::BindC;
545+ bool overload = false ;
546+ if (x.n_decorator_list > 0 ) {
547+ for (size_t i=0 ; i<x.n_decorator_list ; i++) {
548+ AST::expr_t *dec = x.m_decorator_list [i];
549+ if (AST::is_a<AST::Name_t>(*dec)) {
550+ std::string name = AST::down_cast<AST::Name_t>(dec)->m_id ;
551+ if (name == " ccall" ) {
552+ current_procedure_abi_type = ASR::abiType::BindC;
553+ } else if (name == " overload" ) {
554+ overload = true ;
555+ } else {
556+ throw SemanticError (" Decorator: " + name + " is not supported" ,
557+ x.base .base .loc );
558+ }
559+ } else {
560+ throw SemanticError (" Unsupported Decorator type" ,
561+ x.base .base .loc );
543562 }
544563 }
545564 }
@@ -578,6 +597,18 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
578597 var)));
579598 }
580599 std::string sym_name = x.m_name ;
600+ if (overload) {
601+ std::string overload_number;
602+ if (overload_defs.find (sym_name) == overload_defs.end ()){
603+ overload_number = " 0" ;
604+ Vec<ASR::symbol_t *> v;
605+ v.reserve (al, 1 );
606+ overload_defs[sym_name] = v;
607+ } else {
608+ overload_number = std::to_string (overload_defs[sym_name].size ());
609+ }
610+ sym_name = " __lpython_overloaded_" + overload_number + " __" + sym_name;
611+ }
581612 if (parent_scope->scope .find (sym_name) != parent_scope->scope .end ()) {
582613 throw SemanticError (" Subroutine already defined" , tmp->loc );
583614 }
@@ -631,8 +662,23 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
631662 s_access, deftype, bindc_name,
632663 is_pure, is_module);
633664 }
634- parent_scope->scope [sym_name] = ASR::down_cast<ASR::symbol_t >(tmp);
665+ ASR::symbol_t * t = ASR::down_cast<ASR::symbol_t >(tmp);
666+ parent_scope->scope [sym_name] = t;
635667 current_scope = parent_scope;
668+ if (overload) {
669+ overload_defs[x.m_name ].push_back (al, t);
670+ ast_overload[(int64_t )&x] = t;
671+ }
672+ }
673+
674+ void create_GenericProcedure (const Location &loc) {
675+ for (auto &p: overload_defs) {
676+ std::string def_name = p.first ;
677+ tmp = ASR::make_GenericProcedure_t (al, loc, current_scope, s2c (al, def_name),
678+ p.second .p , p.second .size (), ASR::accessType::Public);
679+ ASR::symbol_t *t = ASR::down_cast<ASR::symbol_t >(tmp);
680+ current_scope->scope [def_name] = t;
681+ }
636682 }
637683
638684 void visit_ImportFrom (const AST::ImportFrom_t &x) {
@@ -724,9 +770,10 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
724770};
725771
726772Result<ASR::asr_t *> symbol_table_visitor (Allocator &al, const AST::Module_t &ast,
727- diag::Diagnostics &diagnostics, bool main_module)
773+ diag::Diagnostics &diagnostics, bool main_module,
774+ std::map<int , ASR::symbol_t *> &ast_overload)
728775{
729- SymbolTableVisitor v (al, nullptr , diagnostics, main_module);
776+ SymbolTableVisitor v (al, nullptr , diagnostics, main_module, ast_overload );
730777 try {
731778 v.visit_Module (ast);
732779 } catch (const SemanticError &e) {
@@ -748,8 +795,9 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
748795 ASR::asr_t *asr;
749796 Vec<ASR::stmt_t *> *current_body;
750797
751- BodyVisitor (Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics, bool main_module)
752- : CommonVisitor(al, nullptr , diagnostics, main_module), asr{unit} {}
798+ BodyVisitor (Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics,
799+ bool main_module, std::map<int , ASR::symbol_t *> &ast_overload)
800+ : CommonVisitor(al, nullptr , diagnostics, main_module, ast_overload), asr{unit} {}
753801
754802 // Transforms statements to a list of ASR statements
755803 // In addition, it also inserts the following nodes if needed:
@@ -817,6 +865,16 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
817865 } else if (ASR::is_a<ASR::Function_t>(*t)) {
818866 ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
819867 handle_fn (x, *f);
868+ } else if (ASR::is_a<ASR::GenericProcedure_t>(*t)) {
869+ ASR::symbol_t *s = ast_overload[(int64_t )&x];
870+ if (ASR::is_a<ASR::Subroutine_t>(*s)) {
871+ handle_fn (x, *ASR::down_cast<ASR::Subroutine_t>(s));
872+ } else if (ASR::is_a<ASR::Function_t>(*s)) {
873+ ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(s);
874+ handle_fn (x, *f);
875+ } else {
876+ LFORTRAN_ASSERT (false );
877+ }
820878 } else {
821879 LFORTRAN_ASSERT (false );
822880 }
@@ -2108,8 +2166,15 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
21082166 x.base .base .loc );
21092167 }
21102168
2111- ASR::symbol_t *s = current_scope->resolve_symbol (call_name);
2112-
2169+ ASR::symbol_t *s = current_scope->resolve_symbol (call_name), *s_generic = nullptr ;
2170+ if (s!=nullptr && s->type == ASR::symbolType::GenericProcedure) {
2171+ ASR::GenericProcedure_t *p = ASR::down_cast<ASR::GenericProcedure_t>(s);
2172+ int idx = ASRUtils::select_generic_procedure (args, *p, x.base .base .loc ,
2173+ [&](const std::string &msg, const Location &loc) { throw SemanticError (msg, loc); });
2174+ // Create ExternalSymbol for procedures in different modules.
2175+ s_generic = s;
2176+ s = p->m_procs [idx];
2177+ }
21132178
21142179 if (!s) {
21152180 if (intrinsic_procedures.is_intrinsic (call_name)) {
@@ -2246,10 +2311,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
22462311 value = intrinsic_procedures.comptime_eval (call_name, al, x.base .base .loc , args);
22472312 }
22482313 tmp = ASR::make_FunctionCall_t (al, x.base .base .loc , stemp,
2249- nullptr , args.p , args.size (), nullptr , 0 , a_type, value, nullptr );
2314+ s_generic , args.p , args.size (), nullptr , 0 , a_type, value, nullptr );
22502315 } else if (ASR::is_a<ASR::Subroutine_t>(*s)) {
22512316 tmp = ASR::make_SubroutineCall_t (al, x.base .base .loc , stemp,
2252- nullptr , args.p , args.size (), nullptr );
2317+ s_generic , args.p , args.size (), nullptr );
22532318 } else {
22542319 throw SemanticError (" Unsupported call type for " + call_name,
22552320 x.base .base .loc );
@@ -2265,9 +2330,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
22652330Result<ASR::TranslationUnit_t*> body_visitor (Allocator &al,
22662331 const AST::Module_t &ast,
22672332 diag::Diagnostics &diagnostics,
2268- ASR::asr_t *unit, bool main_module)
2333+ ASR::asr_t *unit, bool main_module,
2334+ std::map<int , ASR::symbol_t *> &ast_overload)
22692335{
2270- BodyVisitor b (al, unit, diagnostics, main_module);
2336+ BodyVisitor b (al, unit, diagnostics, main_module, ast_overload );
22712337 try {
22722338 b.visit_Module (ast);
22732339 } catch (const SemanticError &e) {
@@ -2301,10 +2367,13 @@ std::string pickle_python(AST::ast_t &ast, bool colors, bool indent) {
23012367Result<ASR::TranslationUnit_t*> python_ast_to_asr (Allocator &al,
23022368 AST::ast_t &ast, diag::Diagnostics &diagnostics, bool main_module)
23032369{
2370+ std::map<int , ASR::symbol_t *> ast_overload;
2371+
23042372 AST::Module_t *ast_m = AST::down_cast2<AST::Module_t>(&ast);
23052373
23062374 ASR::asr_t *unit;
2307- auto res = symbol_table_visitor (al, *ast_m, diagnostics, main_module);
2375+ auto res = symbol_table_visitor (al, *ast_m, diagnostics, main_module,
2376+ ast_overload);
23082377 if (res.ok ) {
23092378 unit = res.result ;
23102379 } else {
@@ -2313,7 +2382,8 @@ Result<ASR::TranslationUnit_t*> python_ast_to_asr(Allocator &al,
23132382 ASR::TranslationUnit_t *tu = ASR::down_cast2<ASR::TranslationUnit_t>(unit);
23142383 LFORTRAN_ASSERT (asr_verify (*tu));
23152384
2316- auto res2 = body_visitor (al, *ast_m, diagnostics, unit, main_module);
2385+ auto res2 = body_visitor (al, *ast_m, diagnostics, unit, main_module,
2386+ ast_overload);
23172387 if (res2.ok ) {
23182388 tu = res2.result ;
23192389 } else {
0 commit comments