@@ -266,6 +266,25 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
266266 {handle_argument (al, loc, value_01), handle_argument (al, loc, value_02)},
267267 ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )));
268268 }
269+
270+ static inline bool is_logical_intrinsic_symbolic (ASR::expr_t * expr) {
271+ if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*expr)) {
272+ ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(expr);
273+ int64_t intrinsic_id = intrinsic_func->m_intrinsic_id ;
274+ switch (static_cast <LCompilers::ASRUtils::IntrinsicScalarFunctions>(intrinsic_id)) {
275+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicHasSymbolQ:
276+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAddQ:
277+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ:
278+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ:
279+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicLogQ:
280+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSinQ:
281+ return true ;
282+ default :
283+ return false ;
284+ }
285+ }
286+ return true ;
287+ }
269288 /* ********************************* Utils *********************************/
270289
271290 void visit_Function (const ASR::Function_t &x) {
@@ -514,9 +533,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
514533 if (intrinsic_func->m_type ->type == ASR::ttypeType::SymbolicExpression) {
515534 process_intrinsic_function (x.base .base .loc , intrinsic_func, x.m_target );
516535 } else if (intrinsic_func->m_type ->type == ASR::ttypeType::Logical) {
517- ASR::expr_t * function_call = process_attributes (x.base .base .loc , x.m_value );
518- ASR::stmt_t * stmt = ASRUtils::STMT (ASR::make_Assignment_t (al, x.base .base .loc , x.m_target , function_call, nullptr ));
519- pass_result.push_back (al, stmt);
536+ if (is_logical_intrinsic_symbolic (x.m_value )) {
537+ ASR::expr_t * function_call = process_attributes (x.base .base .loc , x.m_value );
538+ ASR::stmt_t * stmt = ASRUtils::STMT (ASR::make_Assignment_t (al, x.base .base .loc , x.m_target , function_call, nullptr ));
539+ pass_result.push_back (al, stmt);
540+ }
520541 }
521542 } else if (ASR::is_a<ASR::Cast_t>(*x.m_value )) {
522543 ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(x.m_value );
@@ -676,18 +697,22 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
676697 if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*xx.m_test )) {
677698 ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(xx.m_test );
678699 if (intrinsic_func->m_type ->type == ASR::ttypeType::Logical) {
679- ASR::expr_t * function_call = process_attributes (xx.base .base .loc , xx.m_test );
680- xx.m_test = function_call;
700+ if (is_logical_intrinsic_symbolic (xx.m_test )) {
701+ ASR::expr_t * function_call = process_attributes (xx.base .base .loc , xx.m_test );
702+ xx.m_test = function_call;
703+ }
681704 }
682705 } else if (ASR::is_a<ASR::LogicalNot_t>(*xx.m_test )) {
683706 ASR::LogicalNot_t* logical_not = ASR::down_cast<ASR::LogicalNot_t>(xx.m_test );
684707 if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*logical_not->m_arg )) {
685708 ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(logical_not->m_arg );
686709 if (intrinsic_func->m_type ->type == ASR::ttypeType::Logical) {
687- ASR::expr_t * function_call = process_attributes (xx.base .base .loc , logical_not->m_arg );
688- ASR::expr_t * new_logical_not = ASRUtils::EXPR (ASR::make_LogicalNot_t (al, xx.base .base .loc , function_call,
689- logical_not->m_type , logical_not->m_value ));
690- xx.m_test = new_logical_not;
710+ if (is_logical_intrinsic_symbolic (logical_not->m_arg )) {
711+ ASR::expr_t * function_call = process_attributes (xx.base .base .loc , logical_not->m_arg );
712+ ASR::expr_t * new_logical_not = ASRUtils::EXPR (ASR::make_LogicalNot_t (al, xx.base .base .loc , function_call,
713+ logical_not->m_type , logical_not->m_value ));
714+ xx.m_test = new_logical_not;
715+ }
691716 }
692717 }
693718 } else if (ASR::is_a<ASR::SymbolicCompare_t>(*xx.m_test )) {
@@ -784,8 +809,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
784809 // Now create the FunctionCall node for basic_str
785810 print_tmp.push_back (basic_str (x.base .base .loc , target));
786811 } else if (ASR::is_a<ASR::Logical_t>(*ASRUtils::expr_type (val))) {
787- ASR::expr_t * function_call = process_attributes (x.base .base .loc , val);
788- print_tmp.push_back (function_call);
812+ if (is_logical_intrinsic_symbolic (val)) {
813+ ASR::expr_t * function_call = process_attributes (x.base .base .loc , val);
814+ print_tmp.push_back (function_call);
815+ }
789816 }
790817 } else if (ASR::is_a<ASR::Cast_t>(*val)) {
791818 ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
@@ -926,14 +953,15 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
926953 ASR::expr_t * right_tmp = nullptr ;
927954 if (ASR::is_a<ASR::LogicalCompare_t>(*x.m_test )) {
928955 ASR::LogicalCompare_t *l = ASR::down_cast<ASR::LogicalCompare_t>(x.m_test );
956+ if (is_logical_intrinsic_symbolic (l->m_left ) && is_logical_intrinsic_symbolic (l->m_right )) {
957+ left_tmp = process_attributes (x.base .base .loc , l->m_left );
958+ right_tmp = process_attributes (x.base .base .loc , l->m_right );
959+ ASR::expr_t * test = ASRUtils::EXPR (ASR::make_LogicalCompare_t (al, x.base .base .loc , left_tmp,
960+ l->m_op , right_tmp, l->m_type , l->m_value ));
929961
930- left_tmp = process_attributes (x.base .base .loc , l->m_left );
931- right_tmp = process_attributes (x.base .base .loc , l->m_right );
932- ASR::expr_t * test = ASRUtils::EXPR (ASR::make_LogicalCompare_t (al, x.base .base .loc , left_tmp,
933- l->m_op , right_tmp, l->m_type , l->m_value ));
934-
935- ASR::stmt_t *assert_stmt = ASRUtils::STMT (ASR::make_Assert_t (al, x.base .base .loc , test, x.m_msg ));
936- pass_result.push_back (al, assert_stmt);
962+ ASR::stmt_t *assert_stmt = ASRUtils::STMT (ASR::make_Assert_t (al, x.base .base .loc , test, x.m_msg ));
963+ pass_result.push_back (al, assert_stmt);
964+ }
937965 } else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test )) {
938966 ASR::SymbolicCompare_t* s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test );
939967 if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
@@ -949,9 +977,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
949977 } else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*x.m_test )) {
950978 ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(x.m_test );
951979 if (intrinsic_func->m_type ->type == ASR::ttypeType::Logical) {
952- ASR::expr_t * test = process_attributes (x.base .base .loc , x.m_test );
953- ASR::stmt_t *assert_stmt = ASRUtils::STMT (ASR::make_Assert_t (al, x.base .base .loc , test, x.m_msg ));
954- pass_result.push_back (al, assert_stmt);
980+ if (is_logical_intrinsic_symbolic (x.m_test )) {
981+ ASR::expr_t * test = process_attributes (x.base .base .loc , x.m_test );
982+ ASR::stmt_t *assert_stmt = ASRUtils::STMT (ASR::make_Assert_t (al, x.base .base .loc , test, x.m_msg ));
983+ pass_result.push_back (al, assert_stmt);
984+ }
955985 }
956986 } else if (ASR::is_a<ASR::LogicalBinOp_t>(*x.m_test )) {
957987 ASR::LogicalBinOp_t* binop = ASR::down_cast<ASR::LogicalBinOp_t>(x.m_test );
0 commit comments