diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index c37d8752f7fb..1abf80218e6c 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -15,8 +15,8 @@ from dataclasses import dataclass import flax +import jax import jax.numpy as jnp -from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config from ..utils import logging @@ -151,9 +151,7 @@ def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarra Returns: `jnp.ndarray`: scaled input sample """ - (step_index,) = jnp.where(state.timesteps == timestep, size=1) - step_index = step_index[0] - + step_index = jnp.where(state.timesteps == timestep, jnp.arange(state.timesteps.shape[0]), 0).sum() sigma = state.sigmas[step_index] sample = sample / ((sigma**2 + 1) ** 0.5) return sample @@ -163,28 +161,42 @@ def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, curren Compute a linear multistep coefficient. Args: - order (TODO): - t (TODO): - current_order (TODO): + order (`int`): + The order of the linear multistep method. + t (`int`): + The current step index in the inference schedule. + current_order (`int`): + The current order for which to compute the coefficient. """ + num_sigmas = state.sigmas.shape[0] + num_integration_steps = 10 def lms_derivative(tau): - prod = 1.0 - for k in range(order): - if current_order == k: - continue - prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k]) - return prod + num_tau = tau.shape[0] + mask_indices = jnp.broadcast_to( + jnp.arange(num_sigmas).reshape(1, -1), + (num_tau, num_sigmas), + ) + greater_than = t - order + 1 <= mask_indices + lower_than = mask_indices < t + 1 + not_same_value = mask_indices != t - current_order + mask = greater_than & lower_than & not_same_value - integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0] + correct_coeffs = (tau.reshape(-1, 1) - state.sigmas.reshape(1, -1)) / ( + state.sigmas[t - current_order] - state.sigmas.reshape(1, -1) + 1e-5 + ) + coeffs = jnp.where(mask, correct_coeffs, jnp.ones_like(mask)) + return jnp.prod(coeffs, axis=1) - return integrated_coeff + x = jnp.linspace(state.sigmas[t], state.sigmas[t + 1], num_integration_steps) + return jnp.trapezoid(lms_derivative(x), x=x, axis=0) def set_timesteps( self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: tuple = (), + max_order: int = 4, ) -> LMSDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -194,6 +206,8 @@ def set_timesteps( the `FlaxLMSDiscreteScheduler` state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. + max_order (`int`, defaults to `4`): + The maximum multistep order. Used to pre-allocate the derivatives buffer for jittable inference. """ timesteps = jnp.linspace( @@ -215,7 +229,7 @@ def set_timesteps( timesteps = timesteps.astype(jnp.int32) # initial running values - derivatives = jnp.zeros((0,) + shape, dtype=self.dtype) + derivatives = jnp.zeros((max_order,) + shape, dtype=self.dtype) return state.replace( timesteps=timesteps, @@ -256,7 +270,8 @@ def step( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - sigma = state.sigmas[timestep] + step_index = jnp.where(state.timesteps == timestep, jnp.arange(state.timesteps.shape[0]), 0).sum() + sigma = state.sigmas[step_index] + 1e-5 # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": @@ -269,19 +284,22 @@ def step( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) - # 2. Convert to an ODE derivative + # 2. Convert to an ODE derivative and maintain a fixed-size rolling buffer derivative = (sample - pred_original_sample) / sigma - state = state.replace(derivatives=jnp.append(state.derivatives, derivative)) - if len(state.derivatives) > order: - state = state.replace(derivatives=jnp.delete(state.derivatives, 0)) + derivative = derivative.reshape(1, *derivative.shape).astype(self.dtype) + state = state.replace(derivatives=jnp.concatenate([state.derivatives[1:], derivative], axis=0)) - # 3. Compute linear multistep coefficients - order = min(timestep + 1, order) - lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)] - - # 4. Compute previous sample based on the derivatives path - prev_sample = sample + sum( - coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives)) + # 3. Compute linear multistep coefficients and the previous sample based on the derivatives path + effective_order = jnp.minimum(step_index + 1, order) + prev_sample = jax.lax.fori_loop( + 0, + order, + lambda i, val: jnp.where( + i < effective_order, + val + self.get_lms_coefficient(state, effective_order, step_index, i) * state.derivatives[-(i + 1)], + val, + ), + sample, ) if not return_dict: diff --git a/tests/schedulers/test_scheduler_lms_flax.py b/tests/schedulers/test_scheduler_lms_flax.py new file mode 100644 index 000000000000..afc879af08b7 --- /dev/null +++ b/tests/schedulers/test_scheduler_lms_flax.py @@ -0,0 +1,76 @@ +import jax +import numpy as np +from jax import jit + +from diffusers import FlaxLMSDiscreteScheduler + +from ..testing_utils import require_flax + + +@require_flax +class TestFlaxLMSDiscreteSchedulerJit: + def setup_method(self): + self.scheduler = FlaxLMSDiscreteScheduler( + num_train_timesteps=100, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="linear", + ) + self.state = self.scheduler.create_state() + sample_shape = (2, 4, 4, 3) + self.state = self.scheduler.set_timesteps(self.state, num_inference_steps=10, shape=sample_shape) + rng = jax.random.PRNGKey(0) + self.sample = jax.random.normal(rng, sample_shape) + self.model_output = jax.random.normal(jax.random.PRNGKey(1), sample_shape) + + def test_step_jit_matches_eager(self): + timestep = int(self.state.timesteps[3]) + + eager_out, eager_state = self.scheduler.step( + self.state, self.model_output, timestep, self.sample, return_dict=False + ) + jit_step = jit(self.scheduler.step, static_argnums=(5,)) + jit_out, jit_state = jit_step(self.state, self.model_output, timestep, self.sample, 4, False) + + np.testing.assert_allclose(np.array(eager_out), np.array(jit_out), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose( + np.array(eager_state.derivatives), np.array(jit_state.derivatives), rtol=1e-5, atol=1e-5 + ) + + def test_full_loop_jit_matches_eager(self): + def run_loop(state, sample): + for t in state.timesteps: + t_int = int(t) + sample = self.scheduler.scale_model_input(state, sample, t_int) + model_output = self.model_output + out = self.scheduler.step(state, model_output, t_int, sample, return_dict=True) + sample = out.prev_sample + state = out.state + return sample + + eager_sample = run_loop(self.state, self.sample) + + @jit + def run_loop_jit(state, sample, model_output): + def body(i, carry): + state, sample = carry + t = state.timesteps[i] + sample = self.scheduler.scale_model_input(state, sample, t) + out = self.scheduler.step(state, model_output, t, sample, return_dict=True) + return out.state, out.prev_sample + + state, sample = jax.lax.fori_loop(0, state.timesteps.shape[0], body, (state, sample)) + return sample + + jit_sample = run_loop_jit(self.state, self.sample, self.model_output) + np.testing.assert_allclose(np.array(eager_sample), np.array(jit_sample), rtol=1e-4, atol=1e-4) + + def test_get_lms_coefficient_is_jittable(self): + step_index = 3 + order = 4 + + eager_coeff = self.scheduler.get_lms_coefficient(self.state, order, step_index, 0) + jit_coeff_fn = jit(self.scheduler.get_lms_coefficient) + jit_coeff = jit_coeff_fn(self.state, order, step_index, 0) + + np.testing.assert_allclose(np.array(eager_coeff), np.array(jit_coeff), rtol=1e-5, atol=1e-5)