diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 7b207f782079..a02c5294cf7f 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -219,7 +219,9 @@ def scale_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + step_indices = self.index_for_timestep(timestep, schedule_timesteps) + if not torch.is_tensor(step_indices): + step_indices = torch.tensor([step_indices], device=sample.device, dtype=torch.long) elif self.step_index is not None: # add_noise is called after first denoising step (for inpainting) step_indices = [self.step_index] * timestep.shape[0] @@ -388,32 +390,48 @@ def index_for_timestep( self, timestep: Union[float, torch.FloatTensor], schedule_timesteps: Optional[torch.FloatTensor] = None, - ) -> int: + ) -> Union[int, torch.LongTensor]: """ Get the index for the given timestep. Args: timestep (`float` or `torch.FloatTensor`): - The timestep to find the index for. + The timestep to find the index for. When a 1-D tensor is passed, indices are computed in batch + without per-element ``nonzero()`` calls. schedule_timesteps (`torch.FloatTensor`, *optional*): The schedule timesteps to validate against. If `None`, the scheduler's timesteps are used. Returns: - `int`: - The index of the timestep. + `int` or `torch.LongTensor`: + The index (or indices) of the timestep in the schedule. """ if schedule_timesteps is None: schedule_timesteps = self.timesteps - indices = (schedule_timesteps == timestep).nonzero() + return_scalar = not torch.is_tensor(timestep) or timestep.ndim == 0 + if not torch.is_tensor(timestep): + timestep = torch.tensor(timestep, device=schedule_timesteps.device, dtype=schedule_timesteps.dtype) + elif timestep.ndim == 0: + timestep = timestep.reshape(1) + + matches = schedule_timesteps.unsqueeze(0) == timestep.unsqueeze(1) + first_idx = matches.int().argmax(dim=1) + num_matches = matches.sum(dim=1) # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - pos = 1 if len(indices) > 1 else 0 - - return indices[pos].item() + matches_after_first = matches.clone() + batch_idx = torch.arange(matches.shape[0], device=matches.device) + matches_after_first[batch_idx, first_idx] = False + second_idx = matches_after_first.int().argmax(dim=1) + has_second = matches_after_first.any(dim=1) + indices = torch.where((num_matches > 1) & has_second, second_idx, first_idx) + + if return_scalar: + return indices[0].item() + return indices def _init_step_index(self, timestep: Union[float, torch.FloatTensor]) -> None: if self.begin_index is None: diff --git a/tests/schedulers/test_scheduler_flow_match_euler_discrete.py b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py new file mode 100644 index 000000000000..3400339035e6 --- /dev/null +++ b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py @@ -0,0 +1,130 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import unittest + +import torch + +from diffusers import FlowMatchEulerDiscreteScheduler + + +def _legacy_index_for_timestep(scheduler, timestep, schedule_timesteps): + """Reference implementation using per-element nonzero().""" + if not torch.is_tensor(timestep): + timestep = torch.tensor([timestep], device=schedule_timesteps.device, dtype=schedule_timesteps.dtype) + elif timestep.ndim == 0: + timestep = timestep.reshape(1) + + indices = [] + for t in timestep: + matches = (schedule_timesteps == t).nonzero() + pos = 1 if len(matches) > 1 else 0 + indices.append(matches[pos].item()) + return indices[0] if len(indices) == 1 else torch.tensor(indices, device=schedule_timesteps.device) + + +class FlowMatchEulerDiscreteSchedulerTest(unittest.TestCase): + scheduler_class = FlowMatchEulerDiscreteScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + } + config.update(**kwargs) + return config + + def _assert_index_parity(self, scheduler, schedule_timesteps, timesteps): + for t in timesteps: + expected = _legacy_index_for_timestep(scheduler, t, schedule_timesteps) + actual = scheduler.index_for_timestep(t, schedule_timesteps) + self.assertEqual(actual, expected) + + batch = torch.stack([timesteps[0], timesteps[-1], timesteps[len(timesteps) // 2]]) + expected_batch = torch.tensor( + [_legacy_index_for_timestep(scheduler, t, schedule_timesteps) for t in batch], + device=schedule_timesteps.device, + ) + actual_batch = scheduler.index_for_timestep(batch, schedule_timesteps) + torch.testing.assert_close(actual_batch, expected_batch) + + def test_index_for_timestep_even_shift(self): + scheduler = self.scheduler_class(**self.get_default_config(shift=1.0)) + scheduler.set_timesteps(num_inference_steps=10) + self._assert_index_parity(scheduler, scheduler.timesteps, scheduler.timesteps) + + def test_index_for_timestep_non_uniform_shift(self): + scheduler = self.scheduler_class(**self.get_default_config(shift=3.0)) + scheduler.set_timesteps(num_inference_steps=20) + self._assert_index_parity(scheduler, scheduler.timesteps, scheduler.timesteps) + + def test_scale_noise_batch_matches_legacy(self): + scheduler = self.scheduler_class(**self.get_default_config(shift=3.0)) + scheduler.set_timesteps(num_inference_steps=16) + + sample = torch.randn(8, 4, 8, 8) + noise = torch.randn_like(sample) + timesteps = scheduler.timesteps[:8] + + out = scheduler.scale_noise(sample, timesteps, noise) + + legacy_indices = [_legacy_index_for_timestep(scheduler, t, scheduler.timesteps) for t in timesteps] + sigmas = scheduler.sigmas.to(sample.device, dtype=sample.dtype) + sigma = sigmas[torch.tensor(legacy_indices)].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + expected = sigma * noise + (1.0 - sigma) * sample + + torch.testing.assert_close(out, expected) + + def test_scale_noise_training_batch_speedup(self): + scheduler = self.scheduler_class(**self.get_default_config(shift=3.0)) + scheduler.set_timesteps(num_inference_steps=50) + + sample = torch.randn(64, 16, 32, 32) + noise = torch.randn_like(sample) + timesteps = scheduler.timesteps.repeat(64 // scheduler.timesteps.shape[0] + 1)[:64] + + warmup = 5 + repeats = 50 + for _ in range(warmup): + scheduler.scale_noise(sample, timesteps, noise) + + start = time.perf_counter() + for _ in range(repeats): + scheduler.scale_noise(sample, timesteps, noise) + optimized = time.perf_counter() - start + + schedule_timesteps = scheduler.timesteps.to(sample.device) + + def legacy_scale_noise(): + step_indices = [_legacy_index_for_timestep(scheduler, t, schedule_timesteps) for t in timesteps] + sigmas = scheduler.sigmas.to(sample.device, dtype=sample.dtype) + sigma = sigmas[torch.tensor(step_indices)].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + return sigma * noise + (1.0 - sigma) * sample + + for _ in range(warmup): + legacy_scale_noise() + + start = time.perf_counter() + for _ in range(repeats): + legacy_scale_noise() + legacy = time.perf_counter() - start + + self.assertLess( + optimized, legacy, msg=f"expected speedup, got optimized={optimized:.4f}s legacy={legacy:.4f}s" + )