Skip to content

Commit 5495fb4

Browse files
committed
Fix Flax LMS scheduler sigma mismatch and derivatives buffer OOB
Remove the erroneous +1e-5 sigma offset in step() so it matches scale_model_input() and the PyTorch LMSDiscreteScheduler reference. Cap the multistep fori_loop bound by the pre-allocated derivatives buffer size to prevent out-of-bounds indexing when order exceeds max_order from set_timesteps.
1 parent 4eeb570 commit 5495fb4

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

src/diffusers/schedulers/scheduling_lms_discrete_flax.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def step(
271271
)
272272

273273
step_index = jnp.where(state.timesteps == timestep, jnp.arange(state.timesteps.shape[0]), 0).sum()
274-
sigma = state.sigmas[step_index] + 1e-5
274+
sigma = state.sigmas[step_index]
275275

276276
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
277277
if self.config.prediction_type == "epsilon":
@@ -291,9 +291,10 @@ def step(
291291

292292
# 3. Compute linear multistep coefficients and the previous sample based on the derivatives path
293293
effective_order = jnp.minimum(step_index + 1, order)
294+
loop_order = jnp.minimum(order, state.derivatives.shape[0])
294295
prev_sample = jax.lax.fori_loop(
295296
0,
296-
order,
297+
loop_order,
297298
lambda i, val: jnp.where(
298299
i < effective_order,
299300
val + self.get_lms_coefficient(state, effective_order, step_index, i) * state.derivatives[-(i + 1)],

0 commit comments

Comments
 (0)