Skip to content

Make FlaxLMSDiscreteScheduler jittable (#2180)#8

Open
srlynch1 wants to merge 2 commits into
mainfrom
e2e/diffusers-2180
Open

Make FlaxLMSDiscreteScheduler jittable (#2180)#8
srlynch1 wants to merge 2 commits into
mainfrom
e2e/diffusers-2180

Conversation

@srlynch1

Copy link
Copy Markdown
Owner

Summary

  • Replace scipy integration in get_lms_coefficient with JAX-native jnp.trapezoid and vectorized coefficient product
  • Pre-allocate fixed-shape derivatives buffer via max_order in set_timesteps
  • Use jax.lax.fori_loop and step-index sigma lookup in step for jit compatibility
  • Add three @require_flax parity tests (step, full loop, coefficient)

Fixes huggingface#2180

Test plan

  • pytest tests/schedulers/test_scheduler_lms_flax.py -q (3 passed)
  • python utils/check_copies.py
  • ruff check on changed files

Made with Cursor

srlynch1 and others added 2 commits June 21, 2026 21:43
Three require_flax tests verify step, full-loop, and coefficient parity
between eager and jax.jit execution after the jittable scheduler refactor.

Co-authored-by: Cursor <cursoragent@cursor.com>

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.

Fix All in Cursor

Bugbot Autofix prepared fixes for both issues found in the latest run.

  • ✅ Fixed: Step sigma epsilon mismatch
    • Removed the erroneous + 1e-5 offset from step() so sigma matches scale_model_input() and the PyTorch LMSDiscreteScheduler reference.
  • ✅ Fixed: Order exceeds buffer size
    • Capped the multistep fori_loop upper bound with jnp.minimum(order, state.derivatives.shape[0]) to prevent out-of-bounds derivative indexing when order exceeds the max_order buffer from set_timesteps.

Create PR

Or push these changes by commenting:

@cursor push 5495fb409d
Preview (5495fb409d)
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -271,7 +271,7 @@
             )
 
         step_index = jnp.where(state.timesteps == timestep, jnp.arange(state.timesteps.shape[0]), 0).sum()
-        sigma = state.sigmas[step_index] + 1e-5
+        sigma = state.sigmas[step_index]
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         if self.config.prediction_type == "epsilon":
@@ -291,9 +291,10 @@
 
         # 3. Compute linear multistep coefficients and the previous sample based on the derivatives path
         effective_order = jnp.minimum(step_index + 1, order)
+        loop_order = jnp.minimum(order, state.derivatives.shape[0])
         prev_sample = jax.lax.fori_loop(
             0,
-            order,
+            loop_order,
             lambda i, val: jnp.where(
                 i < effective_order,
                 val + self.get_lms_coefficient(state, effective_order, step_index, i) * state.derivatives[-(i + 1)],

You can send follow-ups to the cloud agent here.

Reviewed by Cursor Bugbot for commit 4eeb570. Configure here.


sigma = state.sigmas[timestep]
step_index = jnp.where(state.timesteps == timestep, jnp.arange(state.timesteps.shape[0]), 0).sum()
sigma = state.sigmas[step_index] + 1e-5

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Step sigma epsilon mismatch

Medium Severity

step uses state.sigmas[step_index] + 1e-5 for denoising and the ODE derivative, while scale_model_input uses the same index without the offset. Pipelines scale inputs then call step with the same timestep, so the two paths disagree on noise level versus the PyTorch LMS reference.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 4eeb570. Configure here.

order,
lambda i, val: jnp.where(
i < effective_order,
val + self.get_lms_coefficient(state, effective_order, step_index, i) * state.derivatives[-(i + 1)],

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Order exceeds buffer size

Medium Severity

step loops up to its order argument and indexes state.derivatives[-(i + 1)], but set_timesteps only allocates max_order rows (default 4). Calling step with order greater than max_order without matching set_timesteps causes out-of-bounds indexing; previously the derivatives list could grow with order.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 4eeb570. Configure here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Make FlaxLMSDiscreteScheduler jittable

1 participant