Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 46 additions & 28 deletions src/diffusers/schedulers/scheduling_lms_discrete_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from dataclasses import dataclass

import flax
import jax
import jax.numpy as jnp
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
Expand Down Expand Up @@ -151,9 +151,7 @@ def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarra
Returns:
`jnp.ndarray`: scaled input sample
"""
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]

step_index = jnp.where(state.timesteps == timestep, jnp.arange(state.timesteps.shape[0]), 0).sum()
sigma = state.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
Expand All @@ -163,28 +161,42 @@ def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, curren
Compute a linear multistep coefficient.

Args:
order (TODO):
t (TODO):
current_order (TODO):
order (`int`):
The order of the linear multistep method.
t (`int`):
The current step index in the inference schedule.
current_order (`int`):
The current order for which to compute the coefficient.
"""
num_sigmas = state.sigmas.shape[0]
num_integration_steps = 10

def lms_derivative(tau):
prod = 1.0
for k in range(order):
if current_order == k:
continue
prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k])
return prod
num_tau = tau.shape[0]
mask_indices = jnp.broadcast_to(
jnp.arange(num_sigmas).reshape(1, -1),
(num_tau, num_sigmas),
)
greater_than = t - order + 1 <= mask_indices
lower_than = mask_indices < t + 1
not_same_value = mask_indices != t - current_order
mask = greater_than & lower_than & not_same_value

integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0]
correct_coeffs = (tau.reshape(-1, 1) - state.sigmas.reshape(1, -1)) / (
state.sigmas[t - current_order] - state.sigmas.reshape(1, -1) + 1e-5
)
coeffs = jnp.where(mask, correct_coeffs, jnp.ones_like(mask))
return jnp.prod(coeffs, axis=1)

return integrated_coeff
x = jnp.linspace(state.sigmas[t], state.sigmas[t + 1], num_integration_steps)
return jnp.trapezoid(lms_derivative(x), x=x, axis=0)

def set_timesteps(
self,
state: LMSDiscreteSchedulerState,
num_inference_steps: int,
shape: tuple = (),
max_order: int = 4,
) -> LMSDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Expand All @@ -194,6 +206,8 @@ def set_timesteps(
the `FlaxLMSDiscreteScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
max_order (`int`, defaults to `4`):
The maximum multistep order. Used to pre-allocate the derivatives buffer for jittable inference.
"""

timesteps = jnp.linspace(
Expand All @@ -215,7 +229,7 @@ def set_timesteps(
timesteps = timesteps.astype(jnp.int32)

# initial running values
derivatives = jnp.zeros((0,) + shape, dtype=self.dtype)
derivatives = jnp.zeros((max_order,) + shape, dtype=self.dtype)

return state.replace(
timesteps=timesteps,
Expand Down Expand Up @@ -256,7 +270,8 @@ def step(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

sigma = state.sigmas[timestep]
step_index = jnp.where(state.timesteps == timestep, jnp.arange(state.timesteps.shape[0]), 0).sum()
sigma = state.sigmas[step_index] + 1e-5

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Step sigma epsilon mismatch

Medium Severity

`step` uses `state.sigmas[step_index] + 1e-5` for denoising and the ODE derivative, while `scale_model_input` uses the same index without the offset. Pipelines scale the input and then call `step` with the same timestep, so the two paths disagree on the noise level, in contrast to the PyTorch LMS reference.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 4eeb570. Configure here.


# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
Expand All @@ -269,19 +284,22 @@ def step(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)

# 2. Convert to an ODE derivative
# 2. Convert to an ODE derivative and maintain a fixed-size rolling buffer
derivative = (sample - pred_original_sample) / sigma
state = state.replace(derivatives=jnp.append(state.derivatives, derivative))
if len(state.derivatives) > order:
state = state.replace(derivatives=jnp.delete(state.derivatives, 0))
derivative = derivative.reshape(1, *derivative.shape).astype(self.dtype)
state = state.replace(derivatives=jnp.concatenate([state.derivatives[1:], derivative], axis=0))

# 3. Compute linear multistep coefficients
order = min(timestep + 1, order)
lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]

# 4. Compute previous sample based on the derivatives path
prev_sample = sample + sum(
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives))
# 3. Compute linear multistep coefficients and the previous sample based on the derivatives path
effective_order = jnp.minimum(step_index + 1, order)
prev_sample = jax.lax.fori_loop(
0,
order,
lambda i, val: jnp.where(
i < effective_order,
val + self.get_lms_coefficient(state, effective_order, step_index, i) * state.derivatives[-(i + 1)],

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Order exceeds buffer size

Medium Severity

`step` loops up to its `order` argument and indexes `state.derivatives[-(i + 1)]`, but `set_timesteps` only allocates `max_order` rows (default 4). Calling `step` with `order` greater than `max_order`, without a matching `set_timesteps` call, causes out-of-bounds indexing; previously the derivatives list could grow with `order`.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 4eeb570. Configure here.

val,
),
sample,
)

if not return_dict:
Expand Down
76 changes: 76 additions & 0 deletions tests/schedulers/test_scheduler_lms_flax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import jax
import numpy as np
from jax import jit

from diffusers import FlaxLMSDiscreteScheduler

from ..testing_utils import require_flax


@require_flax
class TestFlaxLMSDiscreteSchedulerJit:
    """Compare eager and `jax.jit`-compiled execution of `FlaxLMSDiscreteScheduler`."""

    def setup_method(self):
        """Create the scheduler, a 10-step inference state and fixed random inputs before each test."""
        sample_shape = (2, 4, 4, 3)

        self.scheduler = FlaxLMSDiscreteScheduler(
            num_train_timesteps=100,
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule="linear",
        )
        fresh_state = self.scheduler.create_state()
        self.state = self.scheduler.set_timesteps(fresh_state, num_inference_steps=10, shape=sample_shape)

        # Distinct, fixed PRNG keys keep the sample and the fake model output reproducible.
        sample_key = jax.random.PRNGKey(0)
        output_key = jax.random.PRNGKey(1)
        self.sample = jax.random.normal(sample_key, sample_shape)
        self.model_output = jax.random.normal(output_key, sample_shape)

    def test_step_jit_matches_eager(self):
        """A single jitted `step` call returns the same sample and derivatives as the eager call."""
        timestep = int(self.state.timesteps[3])
        step_args = (self.state, self.model_output, timestep, self.sample)

        eager_out, eager_state = self.scheduler.step(*step_args, return_dict=False)

        # Positional argument 5 (the trailing `False` below) is marked static for compilation;
        # the value 4 at position 4 is passed as a regular (traced) argument.
        compiled_step = jit(self.scheduler.step, static_argnums=(5,))
        jit_out, jit_state = compiled_step(*step_args, 4, False)

        comparisons = (
            (eager_out, jit_out),
            (eager_state.derivatives, jit_state.derivatives),
        )
        for expected, actual in comparisons:
            np.testing.assert_allclose(np.array(expected), np.array(actual), rtol=1e-5, atol=1e-5)

    def test_full_loop_jit_matches_eager(self):
        """A full loop over all timesteps gives the same final sample eagerly and under `jit`."""
        scheduler = self.scheduler

        def advance(state, sample, model_output, t):
            # One iteration shared by both loops: scale the input, then take a scheduler step.
            scaled = scheduler.scale_model_input(state, sample, t)
            out = scheduler.step(state, model_output, t, scaled, return_dict=True)
            return out.state, out.prev_sample

        def run_loop(state, sample):
            # Eager reference: plain Python iteration with concrete integer timesteps.
            for t in state.timesteps:
                state, sample = advance(state, sample, self.model_output, int(t))
            return sample

        @jit
        def run_loop_jit(state, sample, model_output):
            # Compiled variant: the timestep is read from the state inside `fori_loop`.
            def body(i, carry):
                loop_state, loop_sample = carry
                return advance(loop_state, loop_sample, model_output, loop_state.timesteps[i])

            _, final_sample = jax.lax.fori_loop(0, state.timesteps.shape[0], body, (state, sample))
            return final_sample

        eager_sample = run_loop(self.state, self.sample)
        jit_sample = run_loop_jit(self.state, self.sample, self.model_output)

        np.testing.assert_allclose(np.array(eager_sample), np.array(jit_sample), rtol=1e-4, atol=1e-4)

    def test_get_lms_coefficient_is_jittable(self):
        """`get_lms_coefficient` can be wrapped in `jit` and agrees with its eager result."""
        # Arguments after the state: order=4, step index=3, current order=0.
        coeff_args = (self.state, 4, 3, 0)

        eager_coeff = self.scheduler.get_lms_coefficient(*coeff_args)
        jit_coeff = jit(self.scheduler.get_lms_coefficient)(*coeff_args)

        np.testing.assert_allclose(np.array(eager_coeff), np.array(jit_coeff), rtol=1e-5, atol=1e-5)