55from sympy import Mul # noqa
66
77from conftest import ( # noqa
8- _R , EVAL , assert_blocking , assert_structure , check_array , get_arrays , get_params ,
9- skipif
8+ _R , EVAL , assert_blocking , assert_structure , body0 , check_array , get_arrays ,
9+ get_params , skipif
1010)
1111from devito import ( # noqa
1212 NODE , Abs , ConditionalDimension , Constant , DefaultDimension , Derivative , Dimension ,
@@ -348,8 +348,8 @@ def test_scalar_cond(self):
348348 trees = retrieve_iteration_tree (op )
349349
350350 assert len (trees ) == 3
351- assert_structure (op , ['t' , 't,x,y ' , 't,x,y' ], 'txyxy' )
352- assert trees [0 ].dimensions == [time ]
351+ assert_structure (op , ['t,x,y ' , 't' , 't,x,y' ], 'txyxy' )
352+ assert trees [1 ].dimensions == [time ]
353353
354354
355355class TestAliases :
@@ -2552,6 +2552,7 @@ def test_invariants_with_conditional(self):
25522552 eqn = Eq (u , u - (cos (time_sub * factor * f ) * uf ))
25532553
25542554 op = Operator (eqn , opt = 'advanced' )
2555+
25552556 assert_structure (op , ['t' , 't,fd' , 't,fd,x,y' ], 't,fd,x,y' )
25562557 # Make sure it compiles
25572558 _ = op .cfunction
@@ -2700,10 +2701,12 @@ def test_split_cond(self):
27002701 eq2 = Eq (u .forward , u .forward + cos (time ), implicit_dims = ct )
27012702
27022703 op = Operator ([eq0 , eq1 , eq2 ])
2704+ op (time = 5 )
2705+
27032706 cond = FindNodes (Conditional ).visit (op )
27042707 assert len (cond ) == 3
27052708 # The alias should have been lifted out of the condition
2706- assert 'float r0 = cos(time);' in str (op . body . body [ 0 ] )
2709+ assert 'float r0 = cos(time);' in str (body0 ( op ) )
27072710 scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
27082711 assert len (scalars ) == 1
27092712
@@ -2721,10 +2724,12 @@ def test_split_cond_multi_alias(self):
27212724 eq2 = Eq (u .forward , u .forward + cos (time ) - sin (time ), implicit_dims = ct )
27222725
27232726 op = Operator ([eq0 , eq1 , eq2 ])
2727+ op (time = 5 )
2728+
27242729 cond = FindNodes (Conditional ).visit (op )
27252730 assert len (cond ) == 3
27262731 # The alias should have been lifted out of the condition
2727- assert 'float r3 = cos(time);' in str (op . body . body [ 0 ] )
2732+ assert 'float r3 = cos(time);' in str (body0 ( op ) )
27282733 scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
27292734 assert len (scalars ) == 5
27302735
@@ -2743,6 +2748,7 @@ def test_multi_cond_no_split(self):
27432748 eq2 = Eq (u .forward , u .forward - sin (time ), implicit_dims = ct )
27442749
27452750 op = Operator ([eq0 , eq1 , eq2 ])
2751+ op (time = 5 )
27462752
27472753 assert_structure (
27482754 op ,
@@ -2751,7 +2757,7 @@ def test_multi_cond_no_split(self):
27512757 )
27522758
27532759 scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
2754- assert len (scalars ) == 4
2760+ assert len (scalars ) == 3
27552761
27562762 def test_alias_with_conditional (self ):
27572763 grid = Grid ((11 , 11 ))
@@ -2767,6 +2773,8 @@ def test_alias_with_conditional(self):
27672773 eq2 = Eq (u .forward , u .forward + cos (ct ), implicit_dims = ct )
27682774
27692775 op = Operator ([eq0 , eq1 , eq2 ])
2776+ op (time = 5 )
2777+
27702778 cond = FindNodes (Conditional ).visit (op )
27712779 assert len (cond ) == 3
27722780
0 commit comments