Skip to content

Commit 7fa5e9a

Browse files
committed
compiler: improve lifting processing to avoid aliases missplacement
1 parent 5581b6d commit 7fa5e9a

4 files changed

Lines changed: 36 additions & 15 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
132132
# AliasList -> Schedule
133133
schedule = lower_aliases(aliases, meta, self.opt_maxpar)
134134

135-
variants.append(Variant(schedule, exprs))
135+
if schedule:
136+
variants.append(Variant(schedule, exprs))
136137

137138
if not variants:
138139
return []
@@ -147,7 +148,7 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
147148
# Schedule -> [Clusters]_k
148149
processed, subs = lower_schedule(schedule, meta, self.sregistry,
149150
self.opt_ftemps, self.opt_min_dtype,
150-
self.opt_minmem)
151+
self.opt_minmem, nclusters=len(cgroup))
151152

152153
# [Clusters]_k -> [Clusters]_k (optimization)
153154
if self.opt_multisubdomain:
@@ -272,7 +273,6 @@ def _do_generate(self, exprs, exclude, cbk_search, cbk_compose=None):
272273
free_symbols = i.free_symbols
273274
if {a.function for a in free_symbols} & exclude:
274275
continue
275-
276276
mapper.add(i, make, terms)
277277

278278
return mapper
@@ -853,7 +853,7 @@ def optimize_schedule_rotations(schedule, sregistry):
853853

854854

855855
def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
856-
opt_minmem):
856+
opt_minmem, nclusters=1):
857857
"""
858858
Turn a Schedule into a sequence of Clusters.
859859
"""
@@ -929,20 +929,21 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
929929
# Degenerate case: scalar expression
930930
assert writeto.size == 0
931931

932-
guards = None
933932
is_cond = any(isinstance(d, (SubsamplingFactor, ConditionalDimension))
934933
for d in pivot.free_symbols)
935-
if meta.guards and is_cond:
934+
if meta.guards and is_cond and nclusters > 1:
936935
# Scalar alias that depends on a guard, unsafe to lift out of the guard
937936
# Do not alias
938937
expression = None
939938
callback = lambda idx: uxreplace(pivot, subs) # noqa: B023
940939
else:
941940
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
942-
obj = Temp(name=name, dtype=dtype)
941+
obj = Temp(name=name, dtype=dtype, is_const=True)
943942
expression = Eq(obj, uxreplace(pivot, subs))
944943

945944
callback = lambda idx: obj # noqa: B023
945+
# Only keep the guard if there is no cross-cluster reuse of the scalar
946+
guards = meta.guards if nclusters == 1 else None
946947

947948
# Create the substitution rules for the aliasing expressions
948949
subs.update({

devito/passes/clusters/misc.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,14 @@ def callback(self, clusters, prefix):
102102
# unless the guard is for an outer dimension
103103
guards = {} if c.is_scalar and not (prefix[:-1] and c.guards) else c.guards
104104

105-
lifted.append(c.rebuild(ispace=ispace, properties=properties, guards=guards))
105+
_lifted = c.rebuild(ispace=ispace, properties=properties, guards=guards)
106+
if clusters[max(n-1, 0)].guards != guards and _lifted.is_scalar:
107+
# Heuristic: if the lifted Cluster has different guards than the
108+
# previous one, then we are likely to end up with a separate
109+
# Cluster, hence give up on lifting
110+
processed.append(_lifted)
111+
else:
112+
lifted.append(_lifted)
106113

107114
return lifted + processed
108115

examples/performance/01_gpu.ipynb

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,13 @@
142142
"name": "stderr",
143143
"output_type": "stream",
144144
"text": [
145-
"NUMA domain count autodetection failed, assuming 1\n",
146-
"Operator `Kernel` ran in 0.01 s\n",
145+
"NUMA domain count autodetection failed, assuming 1\n"
146+
]
147+
},
148+
{
149+
"name": "stderr",
150+
"output_type": "stream",
151+
"text": [
147152
"Operator `Kernel` ran in 0.01 s\n"
148153
]
149154
}
@@ -292,9 +297,9 @@
292297
" const int x_stride0 = x_fsz0*y_fsz0;\n",
293298
" const int y_stride0 = y_fsz0;\n",
294299
"\n",
295-
" float r0 = 1.0F/dt;\n",
296-
" float r1 = 1.0F/(h_x*h_x);\n",
297-
" float r2 = 1.0F/(h_y*h_y);\n",
300+
" const float r0 = 1.0F/dt;\n",
301+
" const float r1 = 1.0F/(h_x*h_x);\n",
302+
" const float r2 = 1.0F/(h_y*h_y);\n",
298303
"\n",
299304
" for (int time = time_m; time <= time_M; time += 1)\n",
300305
" {\n",
@@ -340,7 +345,7 @@
340345
"name": "python",
341346
"nbconvert_exporter": "python",
342347
"pygments_lexer": "ipython3",
343-
"version": "3.13.5"
348+
"version": "3.13.11"
344349
}
345350
},
346351
"nbformat": 4,

tests/test_dse.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2552,6 +2552,7 @@ def test_invariants_with_conditional(self):
25522552
eqn = Eq(u, u - (cos(time_sub * factor * f) * uf))
25532553

25542554
op = Operator(eqn, opt='advanced')
2555+
25552556
assert_structure(op, ['t', 't,fd', 't,fd,x,y'], 't,fd,x,y')
25562557
# Make sure it compiles
25572558
_ = op.cfunction
@@ -2700,6 +2701,8 @@ def test_split_cond(self):
27002701
eq2 = Eq(u.forward, u.forward + cos(time), implicit_dims=ct)
27012702

27022703
op = Operator([eq0, eq1, eq2])
2704+
op(ime=5)
2705+
27032706
cond = FindNodes(Conditional).visit(op)
27042707
assert len(cond) == 3
27052708
# The alias should have been lifted out of the condition
@@ -2721,6 +2724,8 @@ def test_split_cond_multi_alias(self):
27212724
eq2 = Eq(u.forward, u.forward + cos(time) - sin(time), implicit_dims=ct)
27222725

27232726
op = Operator([eq0, eq1, eq2])
2727+
op(ime=5)
2728+
27242729
cond = FindNodes(Conditional).visit(op)
27252730
assert len(cond) == 3
27262731
# The alias should have been lifted out of the condition
@@ -2743,6 +2748,7 @@ def test_multi_cond_no_split(self):
27432748
eq2 = Eq(u.forward, u.forward - sin(time), implicit_dims=ct)
27442749

27452750
op = Operator([eq0, eq1, eq2])
2751+
op(time=5)
27462752

27472753
assert_structure(
27482754
op,
@@ -2751,7 +2757,7 @@ def test_multi_cond_no_split(self):
27512757
)
27522758

27532759
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2754-
assert len(scalars) == 4
2760+
assert len(scalars) == 3
27552761

27562762
def test_alias_with_conditional(self):
27572763
grid = Grid((11, 11))
@@ -2767,6 +2773,8 @@ def test_alias_with_conditional(self):
27672773
eq2 = Eq(u.forward, u.forward + cos(ct), implicit_dims=ct)
27682774

27692775
op = Operator([eq0, eq1, eq2])
2776+
op(time=5)
2777+
27702778
cond = FindNodes(Conditional).visit(op)
27712779
assert len(cond) == 3
27722780

0 commit comments

Comments
 (0)