@@ -626,12 +626,92 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
626626 return module_scope->get_symbol (name);
627627 }
628628
629+ ASR::expr_t * process_attributes (Allocator &al, const Location &loc, ASR::expr_t * expr,
630+ SymbolTable* module_scope) {
631+ if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*expr)) {
632+ ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(expr);
633+ int64_t intrinsic_id = intrinsic_func->m_intrinsic_id ;
634+ switch (static_cast <LCompilers::ASRUtils::IntrinsicScalarFunctions>(intrinsic_id)) {
635+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicHasSymbolQ: {
636+ std::string name = " basic_has_symbol" ;
637+ symbolic_dependencies.push_back (name);
638+ if (!module_scope->get_symbol (name)) {
639+ std::string header = " symengine/cwrapper.h" ;
640+ SymbolTable* fn_symtab = al.make_new <SymbolTable>(module_scope);
641+
642+ Vec<ASR::expr_t *> args;
643+ args.reserve (al, 1 );
644+ ASR::symbol_t * arg1 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
645+ al, loc, fn_symtab, s2c (al, " _lpython_return_variable" ), nullptr , 0 , ASR::intentType::ReturnVar,
646+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )),
647+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false ));
648+ fn_symtab->add_symbol (s2c (al, " _lpython_return_variable" ), arg1);
649+ ASR::symbol_t * arg2 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
650+ al, loc, fn_symtab, s2c (al, " x" ), nullptr , 0 , ASR::intentType::In,
651+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_CPtr_t (al, loc)),
652+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
653+ fn_symtab->add_symbol (s2c (al, " x" ), arg2);
654+ args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, loc, arg2)));
655+ ASR::symbol_t * arg3 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
656+ al, loc, fn_symtab, s2c (al, " y" ), nullptr , 0 , ASR::intentType::In,
657+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_CPtr_t (al, loc)),
658+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
659+ fn_symtab->add_symbol (s2c (al, " y" ), arg3);
660+ args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, loc, arg3)));
661+
662+ Vec<ASR::stmt_t *> body;
663+ body.reserve (al, 1 );
664+
665+ Vec<char *> dep;
666+ dep.reserve (al, 1 );
667+
668+ ASR::expr_t * return_var = ASRUtils::EXPR (ASR::make_Var_t (al, loc, fn_symtab->get_symbol (" _lpython_return_variable" )));
669+ ASR::asr_t * subrout = ASRUtils::make_Function_t_util (al, loc,
670+ fn_symtab, s2c (al, name), dep.p , dep.n , args.p , args.n , body.p , body.n ,
671+ return_var, ASR::abiType::BindC, ASR::accessType::Public,
672+ ASR::deftypeType::Interface, s2c (al, name), false , false , false ,
673+ false , false , nullptr , 0 , false , false , false , s2c (al, header));
674+ ASR::symbol_t * symbol = ASR::down_cast<ASR::symbol_t >(subrout);
675+ module_scope->add_symbol (s2c (al, name), symbol);
676+ }
677+
678+ ASR::symbol_t * basic_has_symbol = module_scope->get_symbol (name);
679+ ASR::expr_t * value1 = handle_argument (al, loc, intrinsic_func->m_args [0 ]);
680+ ASR::expr_t * value2 = handle_argument (al, loc, intrinsic_func->m_args [1 ]);
681+ Vec<ASR::call_arg_t > call_args;
682+ call_args.reserve (al, 1 );
683+ ASR::call_arg_t call_arg1, call_arg2;
684+ call_arg1.loc = loc;
685+ call_arg1.m_value = value1;
686+ call_args.push_back (al, call_arg1);
687+ call_arg2.loc = loc;
688+ call_arg2.m_value = value2;
689+ call_args.push_back (al, call_arg2);
690+ return ASRUtils::EXPR (ASRUtils::make_FunctionCall_t_util (al, loc,
691+ basic_has_symbol, basic_has_symbol, call_args.p , call_args.n ,
692+ ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )), nullptr , nullptr ));
693+ break ;
694+ }
695+ default : {
696+ throw LCompilersException (" IntrinsicFunction: `"
697+ + ASRUtils::get_intrinsic_name (intrinsic_id)
698+ + " ` is not implemented" );
699+ }
700+ }
701+ }
702+ return expr;
703+ }
704+
629705 void visit_Assignment (const ASR::Assignment_t &x) {
630706 SymbolTable* module_scope = current_scope->parent ;
631707 if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*x.m_value )) {
632708 ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(x.m_value );
633709 if (intrinsic_func->m_type ->type == ASR::ttypeType::SymbolicExpression) {
634710 process_intrinsic_function (al, x.base .base .loc , intrinsic_func, module_scope, x.m_target );
711+ } else if (intrinsic_func->m_type ->type == ASR::ttypeType::Logical) {
712+ ASR::expr_t * function_call = process_attributes (al, x.base .base .loc , x.m_value , module_scope);
713+ ASR::stmt_t * stmt = ASRUtils::STMT (ASR::make_Assignment_t (al, x.base .base .loc , x.m_target , function_call, nullptr ));
714+ pass_result.push_back (al, stmt);
635715 }
636716 } else if (ASR::is_a<ASR::Cast_t>(*x.m_value )) {
637717 ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(x.m_value );
@@ -770,37 +850,42 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
770850 basic_str_sym, basic_str_sym, call_args.p , call_args.n ,
771851 ASRUtils::TYPE (ASR::make_Character_t (al, x.base .base .loc , 1 , -2 , nullptr )), nullptr , nullptr ));
772852 print_tmp.push_back (function_call);
773- } else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*val) && ASR::is_a<ASR::SymbolicExpression_t>(* ASRUtils::expr_type (val)) ) {
853+ } else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*val)) {
774854 ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(val);
775- ASR::ttype_t *type = ASRUtils::TYPE (ASR::make_SymbolicExpression_t (al, x.base .base .loc ));
776- std::string symengine_var = symengine_stack.push ();
777- ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
778- al, x.base .base .loc , current_scope, s2c (al, symengine_var), nullptr , 0 , ASR::intentType::Local,
779- nullptr , nullptr , ASR::storage_typeType::Default, type, nullptr ,
780- ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false ));
781- current_scope->add_symbol (s2c (al, symengine_var), arg);
782- for (auto &item : current_scope->get_scope ()) {
783- if (ASR::is_a<ASR::Variable_t>(*item.second )) {
784- ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second );
785- this ->visit_Variable (*s);
855+ if (ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type (val))) {
856+ ASR::ttype_t *type = ASRUtils::TYPE (ASR::make_SymbolicExpression_t (al, x.base .base .loc ));
857+ std::string symengine_var = symengine_stack.push ();
858+ ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
859+ al, x.base .base .loc , current_scope, s2c (al, symengine_var), nullptr , 0 , ASR::intentType::Local,
860+ nullptr , nullptr , ASR::storage_typeType::Default, type, nullptr ,
861+ ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false ));
862+ current_scope->add_symbol (s2c (al, symengine_var), arg);
863+ for (auto &item : current_scope->get_scope ()) {
864+ if (ASR::is_a<ASR::Variable_t>(*item.second )) {
865+ ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second );
866+ this ->visit_Variable (*s);
867+ }
786868 }
787- }
788869
789- ASR::expr_t * target = ASRUtils::EXPR (ASR::make_Var_t (al, x.base .base .loc , arg));
790- process_intrinsic_function (al, x.base .base .loc , intrinsic_func, module_scope, target);
870+ ASR::expr_t * target = ASRUtils::EXPR (ASR::make_Var_t (al, x.base .base .loc , arg));
871+ process_intrinsic_function (al, x.base .base .loc , intrinsic_func, module_scope, target);
791872
792- // Now create the FunctionCall node for basic_str
793- ASR::symbol_t * basic_str_sym = declare_basic_str_function (al, x.base .base .loc , module_scope);
794- Vec<ASR::call_arg_t > call_args;
795- call_args.reserve (al, 1 );
796- ASR::call_arg_t call_arg;
797- call_arg.loc = x.base .base .loc ;
798- call_arg.m_value = target;
799- call_args.push_back (al, call_arg);
800- ASR::expr_t * function_call = ASRUtils::EXPR (ASRUtils::make_FunctionCall_t_util (al, x.base .base .loc ,
801- basic_str_sym, basic_str_sym, call_args.p , call_args.n ,
802- ASRUtils::TYPE (ASR::make_Character_t (al, x.base .base .loc , 1 , -2 , nullptr )), nullptr , nullptr ));
803- print_tmp.push_back (function_call);
873+ // Now create the FunctionCall node for basic_str
874+ ASR::symbol_t * basic_str_sym = declare_basic_str_function (al, x.base .base .loc , module_scope);
875+ Vec<ASR::call_arg_t > call_args;
876+ call_args.reserve (al, 1 );
877+ ASR::call_arg_t call_arg;
878+ call_arg.loc = x.base .base .loc ;
879+ call_arg.m_value = target;
880+ call_args.push_back (al, call_arg);
881+ ASR::expr_t * function_call = ASRUtils::EXPR (ASRUtils::make_FunctionCall_t_util (al, x.base .base .loc ,
882+ basic_str_sym, basic_str_sym, call_args.p , call_args.n ,
883+ ASRUtils::TYPE (ASR::make_Character_t (al, x.base .base .loc , 1 , -2 , nullptr )), nullptr , nullptr ));
884+ print_tmp.push_back (function_call);
885+ } else if (ASR::is_a<ASR::Logical_t>(*ASRUtils::expr_type (val))) {
886+ ASR::expr_t * function_call = process_attributes (al, x.base .base .loc , val, module_scope);
887+ print_tmp.push_back (function_call);
888+ }
804889 } else if (ASR::is_a<ASR::Cast_t>(*val)) {
805890 ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
806891 if (cast_t ->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return ;
@@ -951,20 +1036,34 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
9511036 }
9521037
9531038 void visit_Assert (const ASR::Assert_t &x) {
954- if (!ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test )) return ;
955- ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test );
9561039 SymbolTable* module_scope = current_scope->parent ;
9571040 ASR::expr_t * left_tmp = nullptr ;
9581041 ASR::expr_t * right_tmp = nullptr ;
1042+ if (ASR::is_a<ASR::LogicalCompare_t>(*x.m_test )) {
1043+ ASR::LogicalCompare_t *l = ASR::down_cast<ASR::LogicalCompare_t>(x.m_test );
1044+
1045+ left_tmp = process_attributes (al, x.base .base .loc , l->m_left , module_scope);
1046+ right_tmp = process_attributes (al, x.base .base .loc , l->m_right , module_scope);
1047+ ASR::expr_t * test = ASRUtils::EXPR (ASR::make_LogicalCompare_t (al, x.base .base .loc , left_tmp,
1048+ l->m_op , right_tmp, l->m_type , l->m_value ));
1049+
1050+ ASR::stmt_t *assert_stmt = ASRUtils::STMT (ASR::make_Assert_t (al, x.base .base .loc , test, x.m_msg ));
1051+ pass_result.push_back (al, assert_stmt);
1052+ } else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test )) {
1053+ ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test );
1054+ SymbolTable* module_scope = current_scope->parent ;
1055+ ASR::expr_t * left_tmp = nullptr ;
1056+ ASR::expr_t * right_tmp = nullptr ;
9591057
960- ASR::symbol_t * basic_str_sym = declare_basic_str_function (al, x.base .base .loc , module_scope);
961- left_tmp = process_with_basic_str (al, x.base .base .loc , s->m_left , basic_str_sym);
962- right_tmp = process_with_basic_str (al, x.base .base .loc , s->m_right , basic_str_sym);
963- ASR::expr_t * test = ASRUtils::EXPR (ASR::make_StringCompare_t (al, x.base .base .loc , left_tmp,
964- s->m_op , right_tmp, s->m_type , s->m_value ));
1058+ ASR::symbol_t * basic_str_sym = declare_basic_str_function (al, x.base .base .loc , module_scope);
1059+ left_tmp = process_with_basic_str (al, x.base .base .loc , s->m_left , basic_str_sym);
1060+ right_tmp = process_with_basic_str (al, x.base .base .loc , s->m_right , basic_str_sym);
1061+ ASR::expr_t * test = ASRUtils::EXPR (ASR::make_StringCompare_t (al, x.base .base .loc , left_tmp,
1062+ s->m_op , right_tmp, s->m_type , s->m_value ));
9651063
966- ASR::stmt_t *assert_stmt = ASRUtils::STMT (ASR::make_Assert_t (al, x.base .base .loc , test, x.m_msg ));
967- pass_result.push_back (al, assert_stmt);
1064+ ASR::stmt_t *assert_stmt = ASRUtils::STMT (ASR::make_Assert_t (al, x.base .base .loc , test, x.m_msg ));
1065+ pass_result.push_back (al, assert_stmt);
1066+ }
9681067 }
9691068};
9701069
0 commit comments