Skip to content

Commit 8546abe

Browse files
authored
Implementing Symbolic Has Query Method (#2336)
1 parent 307da8b commit 8546abe

6 files changed

Lines changed: 238 additions & 40 deletions

File tree

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ RUN(NAME symbolics_06 LABELS cpython_sym c_sym llvm_sym NOFAST)
713713
RUN(NAME symbolics_07 LABELS cpython_sym c_sym llvm_sym NOFAST)
714714
RUN(NAME symbolics_08 LABELS cpython_sym c_sym llvm_sym)
715715
RUN(NAME symbolics_09 LABELS cpython_sym c_sym llvm_sym NOFAST)
716+
RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST)
716717

717718
RUN(NAME sizeof_01 LABELS llvm c
718719
EXTRAFILES sizeof_01b.c)

integration_tests/symbolics_10.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from sympy import Symbol, sin, pi
2+
from lpython import S
3+
4+
def test_attributes():
5+
w: S = pi
6+
x: S = Symbol('x')
7+
y: S = Symbol('y')
8+
z: S = sin(x)
9+
10+
# test has
11+
assert(w.has(x) == False)
12+
assert(y.has(x) == False)
13+
assert(x.has(x) == True)
14+
assert(x.has(x) == z.has(x))
15+
16+
# test has 2
17+
assert(sin(x).has(x) == True)
18+
assert(sin(x).has(y) == False)
19+
assert(sin(Symbol("x")).has(x) == True)
20+
assert(sin(Symbol("x")).has(y) == False)
21+
assert(sin(Symbol("x")).has(Symbol("x")) == True)
22+
assert(sin(Symbol("x")).has(Symbol("y")) == False)
23+
assert(sin(Symbol("x")).has(Symbol("x")) != sin(Symbol("x")).has(Symbol("y")))
24+
assert(sin(Symbol("x")).has(Symbol("x")) == sin(Symbol("y")).has(Symbol("y")))
25+
26+
test_attributes()

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ enum class IntrinsicScalarFunctions : int64_t {
7676
SymbolicLog,
7777
SymbolicExp,
7878
SymbolicAbs,
79+
SymbolicHasSymbolQ,
7980
// ...
8081
};
8182

@@ -135,6 +136,7 @@ inline std::string get_intrinsic_name(int x) {
135136
INTRINSIC_NAME_CASE(SymbolicLog)
136137
INTRINSIC_NAME_CASE(SymbolicExp)
137138
INTRINSIC_NAME_CASE(SymbolicAbs)
139+
INTRINSIC_NAME_CASE(SymbolicHasSymbolQ)
138140
default : {
139141
throw LCompilersException("pickle: intrinsic_id not implemented");
140142
}
@@ -2908,6 +2910,56 @@ namespace SymbolicInteger {
29082910

29092911
} // namespace SymbolicInteger
29102912

2913+
namespace SymbolicHasSymbolQ {
2914+
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x,
2915+
diag::Diagnostics& diagnostics) {
2916+
ASRUtils::require_impl(x.n_args == 2, "Intrinsic function SymbolicHasSymbolQ"
2917+
"accepts exactly 2 arguments", x.base.base.loc, diagnostics);
2918+
2919+
ASR::ttype_t* left_type = ASRUtils::expr_type(x.m_args[0]);
2920+
ASR::ttype_t* right_type = ASRUtils::expr_type(x.m_args[1]);
2921+
2922+
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*left_type) &&
2923+
ASR::is_a<ASR::SymbolicExpression_t>(*right_type),
2924+
"Both arguments of SymbolicHasSymbolQ must be of type SymbolicExpression",
2925+
x.base.base.loc, diagnostics);
2926+
}
2927+
2928+
static inline ASR::expr_t* eval_SymbolicHasSymbolQ(Allocator &/*al*/,
2929+
const Location &/*loc*/, ASR::ttype_t *, Vec<ASR::expr_t*> &/*args*/) {
2930+
/*TODO*/
2931+
return nullptr;
2932+
}
2933+
2934+
static inline ASR::asr_t* create_SymbolicHasSymbolQ(Allocator& al,
2935+
const Location& loc, Vec<ASR::expr_t*>& args,
2936+
const std::function<void (const std::string &, const Location &)> err) {
2937+
2938+
if (args.size() != 2) {
2939+
err("Intrinsic function SymbolicHasSymbolQ accepts exactly 2 arguments", loc);
2940+
}
2941+
2942+
for (size_t i = 0; i < args.size(); i++) {
2943+
ASR::ttype_t* argtype = ASRUtils::expr_type(args[i]);
2944+
if(!ASR::is_a<ASR::SymbolicExpression_t>(*argtype)) {
2945+
err("Arguments of SymbolicHasSymbolQ function must be of type SymbolicExpression",
2946+
args[i]->base.loc);
2947+
}
2948+
}
2949+
2950+
Vec<ASR::expr_t*> arg_values;
2951+
arg_values.reserve(al, args.size());
2952+
for( size_t i = 0; i < args.size(); i++ ) {
2953+
arg_values.push_back(al, ASRUtils::expr_value(args[i]));
2954+
}
2955+
2956+
ASR::expr_t* compile_time_value = eval_SymbolicHasSymbolQ(al, loc, logical, arg_values);
2957+
return ASR::make_IntrinsicScalarFunction_t(al, loc,
2958+
static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicHasSymbolQ),
2959+
args.p, args.size(), 0, logical, compile_time_value);
2960+
}
2961+
} // namespace SymbolicHasSymbolQ
2962+
29112963
#define create_symbolic_unary_macro(X) \
29122964
namespace X { \
29132965
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \
@@ -3057,6 +3109,8 @@ namespace IntrinsicScalarFunctionRegistry {
30573109
{nullptr, &SymbolicExp::verify_args}},
30583110
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicAbs),
30593111
{nullptr, &SymbolicAbs::verify_args}},
3112+
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicHasSymbolQ),
3113+
{nullptr, &SymbolicHasSymbolQ::verify_args}},
30603114
};
30613115

30623116
static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
@@ -3157,6 +3211,8 @@ namespace IntrinsicScalarFunctionRegistry {
31573211
"SymbolicExp"},
31583212
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicAbs),
31593213
"SymbolicAbs"},
3214+
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicHasSymbolQ),
3215+
"SymbolicHasSymbolQ"},
31603216
};
31613217

31623218

@@ -3210,6 +3266,7 @@ namespace IntrinsicScalarFunctionRegistry {
32103266
{"SymbolicLog", {&SymbolicLog::create_SymbolicLog, &SymbolicLog::eval_SymbolicLog}},
32113267
{"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}},
32123268
{"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}},
3269+
{"has", {&SymbolicHasSymbolQ::create_SymbolicHasSymbolQ, &SymbolicHasSymbolQ::eval_SymbolicHasSymbolQ}},
32133270
};
32143271

32153272
static inline bool is_intrinsic_function(const std::string& name) {

src/libasr/pass/replace_symbolic.cpp

Lines changed: 135 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -626,12 +626,92 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
626626
return module_scope->get_symbol(name);
627627
}
628628

629+
ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr,
630+
SymbolTable* module_scope) {
631+
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*expr)) {
632+
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(expr);
633+
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id;
634+
switch (static_cast<LCompilers::ASRUtils::IntrinsicScalarFunctions>(intrinsic_id)) {
635+
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicHasSymbolQ: {
636+
std::string name = "basic_has_symbol";
637+
symbolic_dependencies.push_back(name);
638+
if (!module_scope->get_symbol(name)) {
639+
std::string header = "symengine/cwrapper.h";
640+
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);
641+
642+
Vec<ASR::expr_t*> args;
643+
args.reserve(al, 1);
644+
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
645+
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
646+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)),
647+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
648+
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
649+
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
650+
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
651+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
652+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
653+
fn_symtab->add_symbol(s2c(al, "x"), arg2);
654+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));
655+
ASR::symbol_t* arg3 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
656+
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
657+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
658+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
659+
fn_symtab->add_symbol(s2c(al, "y"), arg3);
660+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3)));
661+
662+
Vec<ASR::stmt_t*> body;
663+
body.reserve(al, 1);
664+
665+
Vec<char*> dep;
666+
dep.reserve(al, 1);
667+
668+
ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
669+
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
670+
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
671+
return_var, ASR::abiType::BindC, ASR::accessType::Public,
672+
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
673+
false, false, nullptr, 0, false, false, false, s2c(al, header));
674+
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
675+
module_scope->add_symbol(s2c(al, name), symbol);
676+
}
677+
678+
ASR::symbol_t* basic_has_symbol = module_scope->get_symbol(name);
679+
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
680+
ASR::expr_t* value2 = handle_argument(al, loc, intrinsic_func->m_args[1]);
681+
Vec<ASR::call_arg_t> call_args;
682+
call_args.reserve(al, 1);
683+
ASR::call_arg_t call_arg1, call_arg2;
684+
call_arg1.loc = loc;
685+
call_arg1.m_value = value1;
686+
call_args.push_back(al, call_arg1);
687+
call_arg2.loc = loc;
688+
call_arg2.m_value = value2;
689+
call_args.push_back(al, call_arg2);
690+
return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
691+
basic_has_symbol, basic_has_symbol, call_args.p, call_args.n,
692+
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr));
693+
break;
694+
}
695+
default: {
696+
throw LCompilersException("IntrinsicFunction: `"
697+
+ ASRUtils::get_intrinsic_name(intrinsic_id)
698+
+ "` is not implemented");
699+
}
700+
}
701+
}
702+
return expr;
703+
}
704+
629705
void visit_Assignment(const ASR::Assignment_t &x) {
630706
SymbolTable* module_scope = current_scope->parent;
631707
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*x.m_value)) {
632708
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(x.m_value);
633709
if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) {
634710
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, x.m_target);
711+
} else if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
712+
ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, x.m_value, module_scope);
713+
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr));
714+
pass_result.push_back(al, stmt);
635715
}
636716
} else if (ASR::is_a<ASR::Cast_t>(*x.m_value)) {
637717
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(x.m_value);
@@ -770,37 +850,42 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
770850
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
771851
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
772852
print_tmp.push_back(function_call);
773-
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*val) && ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type(val))) {
853+
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*val)) {
774854
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(val);
775-
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc));
776-
std::string symengine_var = symengine_stack.push();
777-
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
778-
al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local,
779-
nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr,
780-
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
781-
current_scope->add_symbol(s2c(al, symengine_var), arg);
782-
for (auto &item : current_scope->get_scope()) {
783-
if (ASR::is_a<ASR::Variable_t>(*item.second)) {
784-
ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second);
785-
this->visit_Variable(*s);
855+
if (ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type(val))) {
856+
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc));
857+
std::string symengine_var = symengine_stack.push();
858+
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
859+
al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local,
860+
nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr,
861+
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
862+
current_scope->add_symbol(s2c(al, symengine_var), arg);
863+
for (auto &item : current_scope->get_scope()) {
864+
if (ASR::is_a<ASR::Variable_t>(*item.second)) {
865+
ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second);
866+
this->visit_Variable(*s);
867+
}
786868
}
787-
}
788869

789-
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg));
790-
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target);
870+
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg));
871+
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target);
791872

792-
// Now create the FunctionCall node for basic_str
793-
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
794-
Vec<ASR::call_arg_t> call_args;
795-
call_args.reserve(al, 1);
796-
ASR::call_arg_t call_arg;
797-
call_arg.loc = x.base.base.loc;
798-
call_arg.m_value = target;
799-
call_args.push_back(al, call_arg);
800-
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
801-
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
802-
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
803-
print_tmp.push_back(function_call);
873+
// Now create the FunctionCall node for basic_str
874+
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
875+
Vec<ASR::call_arg_t> call_args;
876+
call_args.reserve(al, 1);
877+
ASR::call_arg_t call_arg;
878+
call_arg.loc = x.base.base.loc;
879+
call_arg.m_value = target;
880+
call_args.push_back(al, call_arg);
881+
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
882+
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
883+
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
884+
print_tmp.push_back(function_call);
885+
} else if (ASR::is_a<ASR::Logical_t>(*ASRUtils::expr_type(val))) {
886+
ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, val, module_scope);
887+
print_tmp.push_back(function_call);
888+
}
804889
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
805890
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
806891
if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return;
@@ -951,20 +1036,34 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
9511036
}
9521037

9531038
void visit_Assert(const ASR::Assert_t &x) {
954-
if (!ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) return;
955-
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);
9561039
SymbolTable* module_scope = current_scope->parent;
9571040
ASR::expr_t* left_tmp = nullptr;
9581041
ASR::expr_t* right_tmp = nullptr;
1042+
if (ASR::is_a<ASR::LogicalCompare_t>(*x.m_test)) {
1043+
ASR::LogicalCompare_t *l = ASR::down_cast<ASR::LogicalCompare_t>(x.m_test);
1044+
1045+
left_tmp = process_attributes(al, x.base.base.loc, l->m_left, module_scope);
1046+
right_tmp = process_attributes(al, x.base.base.loc, l->m_right, module_scope);
1047+
ASR::expr_t* test = ASRUtils::EXPR(ASR::make_LogicalCompare_t(al, x.base.base.loc, left_tmp,
1048+
l->m_op, right_tmp, l->m_type, l->m_value));
1049+
1050+
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
1051+
pass_result.push_back(al, assert_stmt);
1052+
} else if(ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
1053+
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);
1054+
SymbolTable* module_scope = current_scope->parent;
1055+
ASR::expr_t* left_tmp = nullptr;
1056+
ASR::expr_t* right_tmp = nullptr;
9591057

960-
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
961-
left_tmp = process_with_basic_str(al, x.base.base.loc, s->m_left, basic_str_sym);
962-
right_tmp = process_with_basic_str(al, x.base.base.loc, s->m_right, basic_str_sym);
963-
ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp,
964-
s->m_op, right_tmp, s->m_type, s->m_value));
1058+
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
1059+
left_tmp = process_with_basic_str(al, x.base.base.loc, s->m_left, basic_str_sym);
1060+
right_tmp = process_with_basic_str(al, x.base.base.loc, s->m_right, basic_str_sym);
1061+
ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp,
1062+
s->m_op, right_tmp, s->m_type, s->m_value));
9651063

966-
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
967-
pass_result.push_back(al, assert_stmt);
1064+
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
1065+
pass_result.push_back(al, assert_stmt);
1066+
}
9681067
}
9691068
};
9701069

0 commit comments

Comments
 (0)