Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ def scale_noise(

# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
step_indices = self.index_for_timestep(timestep, schedule_timesteps)
if not torch.is_tensor(step_indices):
step_indices = torch.tensor([step_indices], device=sample.device, dtype=torch.long)
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timestep.shape[0]
Expand Down Expand Up @@ -388,32 +390,48 @@ def index_for_timestep(
self,
timestep: Union[float, torch.FloatTensor],
schedule_timesteps: Optional[torch.FloatTensor] = None,
) -> int:
) -> Union[int, torch.LongTensor]:
"""
Get the index for the given timestep.

Args:
timestep (`float` or `torch.FloatTensor`):
The timestep to find the index for.
The timestep to find the index for. When a 1-D tensor is passed, indices are computed in batch
without per-element ``nonzero()`` calls.
schedule_timesteps (`torch.FloatTensor`, *optional*):
The schedule timesteps to validate against. If `None`, the scheduler's timesteps are used.

Returns:
`int`:
The index of the timestep.
`int` or `torch.LongTensor`:
The index (or indices) of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps

indices = (schedule_timesteps == timestep).nonzero()
return_scalar = not torch.is_tensor(timestep) or timestep.ndim == 0
if not torch.is_tensor(timestep):
timestep = torch.tensor(timestep, device=schedule_timesteps.device, dtype=schedule_timesteps.dtype)
elif timestep.ndim == 0:
timestep = timestep.reshape(1)

matches = schedule_timesteps.unsqueeze(0) == timestep.unsqueeze(1)
first_idx = matches.int().argmax(dim=1)
num_matches = matches.sum(dim=1)

# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0

return indices[pos].item()
matches_after_first = matches.clone()
batch_idx = torch.arange(matches.shape[0], device=matches.device)
matches_after_first[batch_idx, first_idx] = False
second_idx = matches_after_first.int().argmax(dim=1)
has_second = matches_after_first.any(dim=1)
indices = torch.where((num_matches > 1) & has_second, second_idx, first_idx)

if return_scalar:
return indices[0].item()
return indices

def _init_step_index(self, timestep: Union[float, torch.FloatTensor]) -> None:
if self.begin_index is None:
Expand Down
130 changes: 130 additions & 0 deletions tests/schedulers/test_scheduler_flow_match_euler_discrete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import unittest

import torch

from diffusers import FlowMatchEulerDiscreteScheduler


def _legacy_index_for_timestep(scheduler, timestep, schedule_timesteps):
"""Reference implementation using per-element nonzero()."""
if not torch.is_tensor(timestep):
timestep = torch.tensor([timestep], device=schedule_timesteps.device, dtype=schedule_timesteps.dtype)
elif timestep.ndim == 0:
timestep = timestep.reshape(1)

indices = []
for t in timestep:
matches = (schedule_timesteps == t).nonzero()
pos = 1 if len(matches) > 1 else 0
indices.append(matches[pos].item())
return indices[0] if len(indices) == 1 else torch.tensor(indices, device=schedule_timesteps.device)


class FlowMatchEulerDiscreteSchedulerTest(unittest.TestCase):
    """Parity and timing tests for batched timestep lookup in the scheduler.

    Each test compares ``index_for_timestep`` / ``scale_noise`` against the
    per-element reference ``_legacy_index_for_timestep``.
    """

    scheduler_class = FlowMatchEulerDiscreteScheduler

    def get_default_config(self, **kwargs):
        """Return the base scheduler config, with ``kwargs`` overriding entries."""
        config = {
            "num_train_timesteps": 1000,
            "shift": 1.0,
        }
        config.update(**kwargs)
        return config

    @staticmethod
    def _legacy_scale_noise(scheduler, sample, timesteps, noise, schedule_timesteps):
        """Compute ``sigma * noise + (1 - sigma) * sample`` via per-element index lookup."""
        step_indices = [_legacy_index_for_timestep(scheduler, t, schedule_timesteps) for t in timesteps]
        sigmas = scheduler.sigmas.to(sample.device, dtype=sample.dtype)
        sigma = sigmas[torch.tensor(step_indices)].flatten()
        # Append trailing singleton dims so sigma broadcasts against sample.
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)
        return sigma * noise + (1.0 - sigma) * sample

    @staticmethod
    def _best_time(fn, warmup, repeats):
        """Return the fastest single-call wall time of ``fn`` over ``repeats`` runs.

        The minimum is used instead of the sum because slower runs mostly
        reflect interference from other processes, not the code under test.
        """
        for _ in range(warmup):
            fn()
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            fn()
            best = min(best, time.perf_counter() - start)
        return best

    def _assert_index_parity(self, scheduler, schedule_timesteps, timesteps):
        """Assert scalar and batched lookups agree with the legacy reference."""
        for t in timesteps:
            expected = _legacy_index_for_timestep(scheduler, t, schedule_timesteps)
            actual = scheduler.index_for_timestep(t, schedule_timesteps)
            self.assertEqual(actual, expected)

        # Batched call over the first, last and middle timesteps.
        batch = torch.stack([timesteps[0], timesteps[-1], timesteps[len(timesteps) // 2]])
        expected_batch = torch.tensor(
            [_legacy_index_for_timestep(scheduler, t, schedule_timesteps) for t in batch],
            device=schedule_timesteps.device,
        )
        actual_batch = scheduler.index_for_timestep(batch, schedule_timesteps)
        torch.testing.assert_close(actual_batch, expected_batch)

    def test_index_for_timestep_even_shift(self):
        scheduler = self.scheduler_class(**self.get_default_config(shift=1.0))
        scheduler.set_timesteps(num_inference_steps=10)
        self._assert_index_parity(scheduler, scheduler.timesteps, scheduler.timesteps)

    def test_index_for_timestep_non_uniform_shift(self):
        scheduler = self.scheduler_class(**self.get_default_config(shift=3.0))
        scheduler.set_timesteps(num_inference_steps=20)
        self._assert_index_parity(scheduler, scheduler.timesteps, scheduler.timesteps)

    def test_index_for_timestep_duplicate_timesteps(self):
        # A repeated timestep must resolve to its second occurrence; this is
        # the branch the batched implementation has to reproduce.
        scheduler = self.scheduler_class(**self.get_default_config())
        schedule_timesteps = torch.tensor([3.0, 2.0, 2.0, 1.0])
        self._assert_index_parity(scheduler, schedule_timesteps, schedule_timesteps)
        self.assertEqual(scheduler.index_for_timestep(torch.tensor(2.0), schedule_timesteps), 2)

    def test_scale_noise_batch_matches_legacy(self):
        scheduler = self.scheduler_class(**self.get_default_config(shift=3.0))
        scheduler.set_timesteps(num_inference_steps=16)

        sample = torch.randn(8, 4, 8, 8)
        noise = torch.randn_like(sample)
        timesteps = scheduler.timesteps[:8]

        out = scheduler.scale_noise(sample, timesteps, noise)
        expected = self._legacy_scale_noise(scheduler, sample, timesteps, noise, scheduler.timesteps)

        torch.testing.assert_close(out, expected)

    def test_scale_noise_training_batch_speedup(self):
        scheduler = self.scheduler_class(**self.get_default_config(shift=3.0))
        scheduler.set_timesteps(num_inference_steps=50)

        sample = torch.randn(64, 16, 32, 32)
        noise = torch.randn_like(sample)
        # Tile the schedule so there are exactly 64 timesteps, one per sample.
        timesteps = scheduler.timesteps.repeat(64 // scheduler.timesteps.shape[0] + 1)[:64]
        schedule_timesteps = scheduler.timesteps.to(sample.device)

        warmup = 5
        repeats = 50

        optimized = self._best_time(
            lambda: scheduler.scale_noise(sample, timesteps, noise), warmup, repeats
        )
        legacy = self._best_time(
            lambda: self._legacy_scale_noise(scheduler, sample, timesteps, noise, schedule_timesteps),
            warmup,
            repeats,
        )

        # NOTE: still a wall-clock comparison; best-of-N reduces but cannot
        # fully remove sensitivity to machine load.
        self.assertLess(
            optimized, legacy, msg=f"expected speedup, got optimized={optimized:.6f}s legacy={legacy:.6f}s"
        )