Skip to content

Commit 71c95ce

Browse files
authored
Fixing symbolic pass according to error caught in LFortran (#2565)
1 parent 68774de commit 71c95ce

1 file changed

Lines changed: 51 additions & 21 deletions

File tree

src/libasr/pass/replace_symbolic.cpp

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,25 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
266266
{handle_argument(al, loc, value_01), handle_argument(al, loc, value_02)},
267267
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)));
268268
}
269+
270+
static inline bool is_logical_intrinsic_symbolic(ASR::expr_t* expr) {
271+
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*expr)) {
272+
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(expr);
273+
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id;
274+
switch (static_cast<LCompilers::ASRUtils::IntrinsicScalarFunctions>(intrinsic_id)) {
275+
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicHasSymbolQ:
276+
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAddQ:
277+
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ:
278+
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ:
279+
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicLogQ:
280+
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSinQ:
281+
return true;
282+
default:
283+
return false;
284+
}
285+
}
286+
return true;
287+
}
269288
/********************************** Utils *********************************/
270289

271290
void visit_Function(const ASR::Function_t &x) {
@@ -514,9 +533,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
514533
if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) {
515534
process_intrinsic_function(x.base.base.loc, intrinsic_func, x.m_target);
516535
} else if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
517-
ASR::expr_t* function_call = process_attributes(x.base.base.loc, x.m_value);
518-
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr));
519-
pass_result.push_back(al, stmt);
536+
if (is_logical_intrinsic_symbolic(x.m_value)) {
537+
ASR::expr_t* function_call = process_attributes(x.base.base.loc, x.m_value);
538+
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr));
539+
pass_result.push_back(al, stmt);
540+
}
520541
}
521542
} else if (ASR::is_a<ASR::Cast_t>(*x.m_value)) {
522543
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(x.m_value);
@@ -676,18 +697,22 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
676697
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*xx.m_test)) {
677698
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(xx.m_test);
678699
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
679-
ASR::expr_t* function_call = process_attributes(xx.base.base.loc, xx.m_test);
680-
xx.m_test = function_call;
700+
if (is_logical_intrinsic_symbolic(xx.m_test)) {
701+
ASR::expr_t* function_call = process_attributes(xx.base.base.loc, xx.m_test);
702+
xx.m_test = function_call;
703+
}
681704
}
682705
} else if (ASR::is_a<ASR::LogicalNot_t>(*xx.m_test)) {
683706
ASR::LogicalNot_t* logical_not = ASR::down_cast<ASR::LogicalNot_t>(xx.m_test);
684707
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*logical_not->m_arg)) {
685708
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(logical_not->m_arg);
686709
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
687-
ASR::expr_t* function_call = process_attributes(xx.base.base.loc, logical_not->m_arg);
688-
ASR::expr_t* new_logical_not = ASRUtils::EXPR(ASR::make_LogicalNot_t(al, xx.base.base.loc, function_call,
689-
logical_not->m_type, logical_not->m_value));
690-
xx.m_test = new_logical_not;
710+
if (is_logical_intrinsic_symbolic(logical_not->m_arg)) {
711+
ASR::expr_t* function_call = process_attributes(xx.base.base.loc, logical_not->m_arg);
712+
ASR::expr_t* new_logical_not = ASRUtils::EXPR(ASR::make_LogicalNot_t(al, xx.base.base.loc, function_call,
713+
logical_not->m_type, logical_not->m_value));
714+
xx.m_test = new_logical_not;
715+
}
691716
}
692717
}
693718
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*xx.m_test)) {
@@ -784,8 +809,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
784809
// Now create the FunctionCall node for basic_str
785810
print_tmp.push_back(basic_str(x.base.base.loc, target));
786811
} else if (ASR::is_a<ASR::Logical_t>(*ASRUtils::expr_type(val))) {
787-
ASR::expr_t* function_call = process_attributes(x.base.base.loc, val);
788-
print_tmp.push_back(function_call);
812+
if (is_logical_intrinsic_symbolic(val)) {
813+
ASR::expr_t* function_call = process_attributes(x.base.base.loc, val);
814+
print_tmp.push_back(function_call);
815+
}
789816
}
790817
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
791818
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
@@ -926,14 +953,15 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
926953
ASR::expr_t* right_tmp = nullptr;
927954
if (ASR::is_a<ASR::LogicalCompare_t>(*x.m_test)) {
928955
ASR::LogicalCompare_t *l = ASR::down_cast<ASR::LogicalCompare_t>(x.m_test);
956+
if (is_logical_intrinsic_symbolic(l->m_left) && is_logical_intrinsic_symbolic(l->m_right)) {
957+
left_tmp = process_attributes(x.base.base.loc, l->m_left);
958+
right_tmp = process_attributes(x.base.base.loc, l->m_right);
959+
ASR::expr_t* test = ASRUtils::EXPR(ASR::make_LogicalCompare_t(al, x.base.base.loc, left_tmp,
960+
l->m_op, right_tmp, l->m_type, l->m_value));
929961

930-
left_tmp = process_attributes(x.base.base.loc, l->m_left);
931-
right_tmp = process_attributes(x.base.base.loc, l->m_right);
932-
ASR::expr_t* test = ASRUtils::EXPR(ASR::make_LogicalCompare_t(al, x.base.base.loc, left_tmp,
933-
l->m_op, right_tmp, l->m_type, l->m_value));
934-
935-
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
936-
pass_result.push_back(al, assert_stmt);
962+
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
963+
pass_result.push_back(al, assert_stmt);
964+
}
937965
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
938966
ASR::SymbolicCompare_t* s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);
939967
if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
@@ -949,9 +977,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
949977
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*x.m_test)) {
950978
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(x.m_test);
951979
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
952-
ASR::expr_t* test = process_attributes(x.base.base.loc, x.m_test);
953-
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
954-
pass_result.push_back(al, assert_stmt);
980+
if (is_logical_intrinsic_symbolic(x.m_test)) {
981+
ASR::expr_t* test = process_attributes(x.base.base.loc, x.m_test);
982+
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
983+
pass_result.push_back(al, assert_stmt);
984+
}
955985
}
956986
} else if (ASR::is_a<ASR::LogicalBinOp_t>(*x.m_test)) {
957987
ASR::LogicalBinOp_t* binop = ASR::down_cast<ASR::LogicalBinOp_t>(x.m_test);

0 commit comments

Comments
 (0)