Skip to content

Commit dcaf253

Browse files
authored
Merge pull request #2404 from anutosh491/Fixing_assert
Added support for `visit_Assert` through `basic_eq`
2 parents 35a480b + 06e0a80 commit dcaf253

2 files changed

Lines changed: 31 additions & 9 deletions

File tree

integration_tests/symbolics_01.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,10 @@ def main0():
1313
assert(z == pi + y)
1414
assert(z != S(2)*pi + y)
1515

16+
# testing PR 2404
17+
p: S = Symbol('pi')
18+
print(p)
19+
print(p != pi)
20+
assert(p != pi)
21+
1622
main0()

src/libasr/pass/replace_symbolic.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,16 +1706,32 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
17061706
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
17071707
pass_result.push_back(al, assert_stmt);
17081708
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
1709-
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);
1710-
1711-
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
1712-
left_tmp = process_with_basic_str(al, x.base.base.loc, s->m_left, basic_str_sym);
1713-
right_tmp = process_with_basic_str(al, x.base.base.loc, s->m_right, basic_str_sym);
1714-
ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp,
1715-
s->m_op, right_tmp, s->m_type, s->m_value));
1709+
ASR::SymbolicCompare_t* s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);
1710+
if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
1711+
ASR::symbol_t* sym = nullptr;
1712+
if (s->m_op == ASR::cmpopType::Eq) {
1713+
sym = declare_basic_eq_function(al, x.base.base.loc, module_scope);
1714+
} else {
1715+
sym = declare_basic_neq_function(al, x.base.base.loc, module_scope);
1716+
}
1717+
ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left);
1718+
ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right);
1719+
Vec<ASR::call_arg_t> call_args;
1720+
call_args.reserve(al, 1);
1721+
ASR::call_arg_t call_arg1, call_arg2;
1722+
call_arg1.loc = x.base.base.loc;
1723+
call_arg1.m_value = value1;
1724+
call_arg2.loc = x.base.base.loc;
1725+
call_arg2.m_value = value2;
1726+
call_args.push_back(al, call_arg1);
1727+
call_args.push_back(al, call_arg2);
1728+
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
1729+
sym, sym, call_args.p, call_args.n,
1730+
ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr));
17161731

1717-
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
1718-
pass_result.push_back(al, assert_stmt);
1732+
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, function_call, x.m_msg));
1733+
pass_result.push_back(al, assert_stmt);
1734+
}
17191735
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*x.m_test)) {
17201736
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(x.m_test);
17211737
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {

0 commit comments

Comments
 (0)