Skip to content

Commit e54b7b0

Browse files
committed
misc: add test and misc tweaks
1 parent cb454c8 commit e54b7b0

3 files changed

Lines changed: 44 additions & 8 deletions

File tree

devito/passes/clusters/misc.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,9 @@ def callback(self, clusters, prefix):
100100
# If `c` is made of scalar expressions within guards, then we must keep
101101
# it close to the adjacent Clusters for correctness
102102
if c.is_scalar and c.guards:
103-
items = processed
103+
processed.append(c.rebuild(ispace=ispace, properties=properties))
104104
else:
105-
items = lifted
106-
107-
items.append(c.rebuild(ispace=ispace, properties=properties))
105+
lifted.append(c.rebuild(ispace=ispace, properties=properties))
108106

109107
return lifted + processed
110108

devito/symbolics/search.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,18 @@ def visit_preorder_first_hit(self, expr: Expression) -> Iterator[Expression]:
9696

9797

9898
def search(exprs: Expression | Iterable[Expression],
99-
query: type | Callable[[Any], bool],
99+
query: type | tuple[type, ...] | Callable[[Any], bool],
100100
mode: Mode = 'unique',
101101
visit: Literal['dfs', 'bfs', 'bfs_first_hit'] = 'dfs',
102102
deep: bool = False) -> List | set[Expression]:
103103
"""Interface to Search."""
104104

105105
assert mode in ('all', 'unique'), "Unknown mode"
106106

107-
Q = (lambda obj: isinstance(obj, query)) if isinstance(query, type) else query
107+
if isinstance(query, (type, tuple)):
108+
Q = lambda obj: isinstance(obj, query)
109+
else:
110+
Q = query
108111

109112
# Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
110113
# is retained in this function's parameters for backwards compatibility

tests/test_dse.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2800,8 +2800,43 @@ def test_scalar_alias_interp(self):
28002800

28012801
op.apply(time_M=3)
28022802

2803-
assert np.isclose(norm(f), 254292.75, atol=1e-1, rtol=0)
2804-
assert np.isclose(norm(s), 191.44644, atol=1e-1, rtol=0)
2803+
assert np.isclose(norm(f), 254292.75, atol=0, rtol=1e-5)
2804+
assert np.isclose(norm(s), 191.44644, atol=0, rtol=1e-4)
2805+
2806+
def test_scalar_with_cond_access(self):
2807+
grid = Grid((11, 11))
2808+
time = grid.time_dim
2809+
2810+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2811+
2812+
ct = ConditionalDimension(name='ct3', parent=time, condition=Ge(time, 2))
2813+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2814+
2815+
f1 = TimeFunction(name='f1', grid=grid, save=10, time_order=0,
2816+
dimensions=(ct,), time_dim=ct, shape=(10,))
2817+
f1.data[:] = np.arange(10)
2818+
2819+
eq0 = Eq(u.forward, u + cos(f1))
2820+
eq1 = Eq(u.forward, u.forward + sin(time), implicit_dims=ct2)
2821+
eq2 = Eq(u.forward, u.forward - sin(f1))
2822+
2823+
op = Operator([eq0, eq1, eq2])
2824+
2825+
cond = FindNodes(Conditional).visit(op)
2826+
assert len(cond) == 3
2827+
2828+
# # Each guard should have its own alias for cos/sin(f1[time-2])
2829+
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2830+
assert len(scalars) == 3
2831+
2832+
assert_structure(
2833+
op,
2834+
['t', 't,x,y', 't,x,y', 't,x,y'],
2835+
'txyxyxy'
2836+
)
2837+
2838+
# This would segfault without the right placement of the alias
2839+
op.apply(time_M=12)
28052840

28062841

28072842
class TestIsoAcoustic:

0 commit comments

Comments
 (0)