@@ -132,7 +132,8 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
132132 # AliasList -> Schedule
133133 schedule = lower_aliases (aliases , meta , self .opt_maxpar )
134134
135- variants .append (Variant (schedule , exprs ))
135+ if schedule :
136+ variants .append (Variant (schedule , exprs ))
136137
137138 if not variants :
138139 return []
@@ -147,7 +148,7 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
147148 # Schedule -> [Clusters]_k
148149 processed , subs = lower_schedule (schedule , meta , self .sregistry ,
149150 self .opt_ftemps , self .opt_min_dtype ,
150- self .opt_minmem )
151+ self .opt_minmem , nclusters = len ( cgroup ) )
151152
152153 # [Clusters]_k -> [Clusters]_k (optimization)
153154 if self .opt_multisubdomain :
@@ -272,7 +273,6 @@ def _do_generate(self, exprs, exclude, cbk_search, cbk_compose=None):
272273 free_symbols = i .free_symbols
273274 if {a .function for a in free_symbols } & exclude :
274275 continue
275-
276276 mapper .add (i , make , terms )
277277
278278 return mapper
@@ -853,7 +853,7 @@ def optimize_schedule_rotations(schedule, sregistry):
853853
854854
855855def lower_schedule (schedule , meta , sregistry , opt_ftemps , opt_min_dtype ,
856- opt_minmem ):
856+ opt_minmem , nclusters = 1 ):
857857 """
858858 Turn a Schedule into a sequence of Clusters.
859859 """
@@ -929,20 +929,21 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
929929 # Degenerate case: scalar expression
930930 assert writeto .size == 0
931931
932- guards = None
933932 is_cond = any (isinstance (d , (SubsamplingFactor , ConditionalDimension ))
934933 for d in pivot .free_symbols )
935- if meta .guards and is_cond :
934+ if meta .guards and is_cond and nclusters > 1 :
936935 # Scalar alias that depends on a guard, unsafe to lift out of the guard
937936 # Do not alias
938937 expression = None
939938 callback = lambda idx : uxreplace (pivot , subs ) # noqa: B023
940939 else :
941940 dtype = sympy_dtype (pivot , base = meta .dtype , smin = opt_min_dtype )
942- obj = Temp (name = name , dtype = dtype )
941+ obj = Temp (name = name , dtype = dtype , is_const = True )
943942 expression = Eq (obj , uxreplace (pivot , subs ))
944943
945944 callback = lambda idx : obj # noqa: B023
945+ # Only keep the guard if there is no cross-cluster reuse of the scalar
946+ guards = meta .guards if nclusters == 1 else None
946947
947948 # Create the substitution rules for the aliasing expressions
948949 subs .update ({
0 commit comments