Make FlaxLMSDiscreteScheduler jittable (#2180)#8
Conversation
Three require_flax tests verify step, full-loop, and coefficient parity between eager and jax.jit execution after the jittable scheduler refactor. Co-authored-by: Cursor <cursoragent@cursor.com>
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.
Bugbot Autofix prepared fixes for both issues found in the latest run.
- ✅ Fixed: Step sigma epsilon mismatch
- Removed the erroneous
+ 1e-5offset fromstep()so sigma matchesscale_model_input()and the PyTorch LMSDiscreteScheduler reference.
- Removed the erroneous
- ✅ Fixed: Order exceeds buffer size
- Capped the multistep
fori_loopupper bound withjnp.minimum(order, state.derivatives.shape[0])to prevent out-of-bounds derivative indexing whenorderexceeds themax_orderbuffer fromset_timesteps.
- Capped the multistep
Or push these changes by commenting:
@cursor push 5495fb409d
Preview (5495fb409d)
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -271,7 +271,7 @@
)
step_index = jnp.where(state.timesteps == timestep, jnp.arange(state.timesteps.shape[0]), 0).sum()
- sigma = state.sigmas[step_index] + 1e-5
+ sigma = state.sigmas[step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
@@ -291,9 +291,10 @@
# 3. Compute linear multistep coefficients and the previous sample based on the derivatives path
effective_order = jnp.minimum(step_index + 1, order)
+ loop_order = jnp.minimum(order, state.derivatives.shape[0])
prev_sample = jax.lax.fori_loop(
0,
- order,
+ loop_order,
lambda i, val: jnp.where(
i < effective_order,
val + self.get_lms_coefficient(state, effective_order, step_index, i) * state.derivatives[-(i + 1)],You can send follow-ups to the cloud agent here.
Reviewed by Cursor Bugbot for commit 4eeb570. Configure here.
|
|
||
| sigma = state.sigmas[timestep] | ||
| step_index = jnp.where(state.timesteps == timestep, jnp.arange(state.timesteps.shape[0]), 0).sum() | ||
| sigma = state.sigmas[step_index] + 1e-5 |
There was a problem hiding this comment.
Step sigma epsilon mismatch
Medium Severity
step uses state.sigmas[step_index] + 1e-5 for denoising and the ODE derivative, while scale_model_input uses the same index without the offset. Pipelines scale inputs then call step with the same timestep, so the two paths disagree on noise level versus the PyTorch LMS reference.
Additional Locations (1)
Reviewed by Cursor Bugbot for commit 4eeb570. Configure here.
| order, | ||
| lambda i, val: jnp.where( | ||
| i < effective_order, | ||
| val + self.get_lms_coefficient(state, effective_order, step_index, i) * state.derivatives[-(i + 1)], |
There was a problem hiding this comment.
Order exceeds buffer size
Medium Severity
step loops up to its order argument and indexes state.derivatives[-(i + 1)], but set_timesteps only allocates max_order rows (default 4). Calling step with order greater than max_order without matching set_timesteps causes out-of-bounds indexing; previously the derivatives list could grow with order.
Additional Locations (1)
Reviewed by Cursor Bugbot for commit 4eeb570. Configure here.



Summary
get_lms_coefficientwith JAX-nativejnp.trapezoidand vectorized coefficient productmax_orderinset_timestepsjax.lax.fori_loopand step-index sigma lookup instepfor jit compatibility@require_flaxparity tests (step, full loop, coefficient)Fixes huggingface#2180
Test plan
pytest tests/schedulers/test_scheduler_lms_flax.py -q(3 passed)python utils/check_copies.pyMade with Cursor