[trainer,hparams,examples,docs] feat: add DGPO (Direct Group Preference Optimization) trainer #133
Pull request overview
Adds a new DGPO (Direct Group Preference Optimization) trainer to Flow-Factory, plus supporting RNG utilities and documentation/example wiring.
Changes:
- Extend `create_generator` and `TimeSampler` APIs to support optional device/generator inputs for deterministic draws.
- Add `DGPOTrainer`, `DGPOTrainingArguments`, and register them under the `dgpo` trainer/hparams registries.
- Add a runnable DGPO example config and update guidance docs to document DGPO and list it in trainer/algorithm tables.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/flow_factory/utils/noise_schedule.py | Adds optional generator plumbing to timestep samplers for deterministic/cross-rank draws. |
| src/flow_factory/utils/base.py | Extends create_generator with an optional device argument. |
| src/flow_factory/trainers/registry.py | Registers the new dgpo trainer entry. |
| src/flow_factory/trainers/dgpo.py | Implements the DGPO training loop, EMA-ref logic, shared RNG/noise, and group loss. |
| src/flow_factory/hparams/training_args.py | Introduces DGPOTrainingArguments and registers them under dgpo. |
| src/flow_factory/hparams/__init__.py | Exports DGPOTrainingArguments. |
| guidance/workflow.md | Updates workflow tables/notes to include DGPO behavior. |
| guidance/algorithms.md | Adds a DGPO section with hyperparameter documentation and references. |
| examples/dgpo/lora/sd3_5.yaml | Provides an end-to-end DGPO example config (SD3.5 + PickScore). |
| AGENTS.md | Adds DGPO to the listed algorithms and updates docs pointers. |
| .agents/knowledge/architecture.md | Adds DGPO to the trainer registry table. |
    for inner_epoch in range(self.training_args.num_inner_epochs):
        perm_gen = create_generator(self.training_args.seed, self.epoch, inner_epoch)
        perm = torch.randperm(len(samples), generator=perm_gen)
        shuffled = [samples[i] for i in perm]

        sample_slices = [shuffled[i : i + bsz] for i in range(0, len(shuffled), bsz)]

        # Pre-compute per-minibatch inputs that do not depend on the current
        # policy: stacked batches, shared timesteps, group info. Training
        # noise is generated per micro-step directly in this loop.
        # Keeping a list of dicts (instead of mutating the stacked batch in
        # place) avoids polluting the batch with training-only keys.
        shared_timesteps = self._sample_shared_timesteps(inner_epoch)  # (T,)
        training_batches = self._build_training_batches(
            sample_slices, shared_timesteps, inner_epoch
        )
DGPO’s group-level weighting assumes each unique_id’s full group contributes to global_sums, but optimize() shuffles/reshards individual samples and then forms sample_slices by per_device_batch_size. With the current samplers (e.g. DistributedKRepeatSampler shuffles all m*k indices each epoch), members of the same group will generally be split across multiple slices/steps, so _compute_group_dgpo_loss() will compute sigmoid weights from partial group sums and change the objective.
To make the group loss correct, batch/iterate in a way that keeps complete groups together (or do a two-pass per timestep: first accumulate per-sample deltas into per-group sums across all samples, then apply the resulting group weights when forming the final DSM loss).
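The concern above can be made concrete with a small pure-Python sketch. It assumes, as the comment describes, that the group weight is a sigmoid of the summed per-sample terms; the `deltas` values and the helper name `group_weights` are invented for illustration and are not taken from the trainer.

```python
import math
from collections import defaultdict


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def group_weights(group_ids, deltas):
    """Sum per-sample terms by group id, then squash each sum with a sigmoid."""
    sums = defaultdict(float)
    for gid, delta in zip(group_ids, deltas):
        sums[gid] += delta  # pure-Python stand-in for a scatter_add
    return {gid: sigmoid(total) for gid, total in sums.items()}


# One group (id 0) with K = 4 samples; the deltas are made-up numbers.
group_ids = [0, 0, 0, 0]
deltas = [0.5, 0.25, -1.0, 2.0]

full = group_weights(group_ids, deltas)                  # whole group in one batch
first_half = group_weights(group_ids[:2], deltas[:2])    # micro-batch 1 (bsz = 2)
second_half = group_weights(group_ids[2:], deltas[2:])   # micro-batch 2 (bsz = 2)

print(full[0], first_half[0], second_half[0])
```

With the group split across two micro-batches, each step sees a weight built from a partial sum (0.75 or 1.0 here) rather than from the full-group sum (1.75), which is the discrepancy the comment points at.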
Thanks. I audited this against the upstream reference (Luo-Yihong/DGPO: scripts/train_sd3_dgpo.py):
- Upstream reshapes samples into `train_batch_size=12`-sized micro-batches at L1048-1056 with `K = num_image_per_prompt = 24`, i.e. the same `K > bsz` situation you flagged here.
- Upstream's `compute_group_dgpo_loss_allreduce` (L162-193) runs the scatter-add + `dist.all_reduce` + `sigmoid` inside each micro-batch, so it too produces "partial-group sigmoid weights" per step.
So the per-micro-batch stochastic group weight is the intended objective, not a corruption — it's the standard SGD approximation of the group preference loss and matches the paper's released code byte-for-byte. The per-sample advantages carried into the sigmoid have already been normalized over the complete group in prepare_feedback, so the stochastic sigmoid reweight stays well-behaved. Validated end-to-end on SD3.5-medium + PickScore — reward trends upward.
Pushed an amend (fee9e36) that expands _compute_group_dgpo_loss's docstring to make the per-micro-batch stochastic semantics explicit and cites the reference implementation lines, so future readers don't have to re-derive it.
    ema_ref_device: str = field(
        default='cuda',
        metadata={"help": "Device for old-policy EMA ref parameters ('cuda' or 'cpu')."},
    )

    # Timestep control
    num_train_timesteps: int = field(
        default=0,
        metadata={"help": "Number of training timesteps per sample. 0 defaults to `int(num_inference_steps * (timestep_range[1] - timestep_range[0]))`."},
    )
    time_sampling_strategy: Literal['uniform', 'logit_normal', 'discrete', 'discrete_with_init', 'discrete_wo_init'] = field(
        default='discrete',
        metadata={"help": "Strategy for sampling training timesteps."},
    )
    time_shift: float = field(
        default=3.0,
        metadata={"help": "Shift parameter for logit-normal timestep sampling."},
    )
    timestep_range: Union[float, Tuple[float, float]] = field(
        default=0.6,
        metadata={"help": "Timestep range for discrete sampling. Float for [0, value], tuple for [start, end]."},
    )

    def __post_init__(self):
        super().__post_init__()
        self.timestep_range = _standardize_timestep_range(self.timestep_range)
        if not self.num_train_timesteps or self.num_train_timesteps <= 0:
            self.num_train_timesteps = max(1, int(self.num_inference_steps * (self.timestep_range[1] - self.timestep_range[0])))
ema_ref_device is typed as a free-form str, and DGPOTrainer currently treats anything other than 'cuda' as CPU. This can silently misconfigure runs (e.g. 'cuda:0' or typo) and hurt performance.
Consider validating in DGPOTrainingArguments.__post_init__ (or typing as Literal['cpu','cuda']) and raising a clear error for unsupported values, so misconfiguration fails fast.
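A minimal sketch of the fail-fast validation suggested above. `DeviceArgsSketch` is a hypothetical stand-in, not the real `DGPOTrainingArguments`. Plain dataclasses do not check `Literal` annotations at runtime, so the sketch assumes an explicit check in `__post_init__`; whether the project's argument parser already rejects other values is not shown in this thread.

```python
from dataclasses import dataclass, field
from typing import Literal, get_args

DeviceName = Literal["cpu", "cuda"]


@dataclass
class DeviceArgsSketch:
    """Hypothetical stand-in for the device field of DGPOTrainingArguments."""

    ema_ref_device: DeviceName = field(
        default="cuda",
        metadata={"help": "Device for old-policy EMA ref parameters ('cuda' or 'cpu')."},
    )

    def __post_init__(self) -> None:
        allowed = get_args(DeviceName)  # ("cpu", "cuda")
        if self.ema_ref_device not in allowed:
            raise ValueError(
                f"ema_ref_device must be one of {allowed}, got {self.ema_ref_device!r}"
            )


def is_rejected(value: str) -> bool:
    """Return True when constructing the arguments with `value` raises."""
    try:
        DeviceArgsSketch(ema_ref_device=value)
    except ValueError:
        return True
    return False


print(is_rejected("cuda:0"), is_rejected("cpu"))
```

With this check a value such as `'cuda:0'` or a typo raises at construction time instead of silently falling through to the CPU branch.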
Good catch — every other device field in training_args.py is already Literal["cpu", "cuda"] (ema_device L220, ref_param_device ×4 at L413/465/556/614). ema_ref_device: str was an inconsistency left over from an earlier iteration. Fixed in fee9e36 — now Literal["cpu", "cuda"]. The DGPOTrainer.__init__ consumer only branches on == "cuda" so there's no behavior change, just stronger static typing + CLI arg validation via dataclass.
…wo-pass

Fix a stochastic-gradient artefact in the DGPO group loss: the previous implementation's ``sigmoid(group_sums)`` was built from **partial group sums** because (a) ``DistributedKRepeatSampler`` spreads a group's K copies across ranks, (b) the in-rank ``randperm`` at the top of ``optimize()`` further splits a group across local micro-batches, and (c) ``_compute_group_dgpo_loss`` was invoked per micro-batch. Every optimizer step then saw a different, random subset of each group; Copilot flagged the symptom on PR #133 L644-660.

The fix replaces the per-minibatch shuffle + per-minibatch sigmoid with a single, uniform front-end:

1. ``_gather_and_reshard()`` at the start of every ``optimize()`` call runs ``gather_samples`` (reusing ``DPOTrainer``'s pattern at ``trainers/dpo.py:253``), buckets samples by ``unique_id``, permutes whole groups with a cross-rank-consistent seed, and reshards with ``groups[rank::world_size]``. After this every rank owns an integer number of **complete** groups, regardless of sampler topology.
2. With that invariant, two code paths cover every batch geometry:
   - ``_optimize_exact`` (Path A, ``B % K == 0``): each local micro-batch packs ``B/K`` whole groups, so ``_compute_group_dgpo_loss`` computes ``sigmoid`` on full-group sums in a single forward. No cross-rank reduce needed.
   - ``_optimize_two_pass`` (Path B, otherwise; the common case for ``B < K``): outer loop is ``for t_idx``, inner runs pass 1 (no-grad forward over every micro-batch, ``scatter_add`` into a rank-wide ``global_sums``, freeze ``sigmoid``) then pass 2 (grad forward + backward consuming the frozen weights via ``_apply_group_weighted_loss``). Pass 1 / pass 2 bit-identical noise is guaranteed by force-enabling ``use_shared_noise=True`` in this path.

Side effects:
- ``_precompute_group_info`` loses its ``distributed_k_repeat`` gather branch (resharded samples are always rank-local for group purposes).
- ``_compute_group_dgpo_loss`` loses the ``_needs_cross_rank_reduce`` dispatch; the sum is always local after reshard.
- ``_group_on_same_rank`` property retained for parity with ``AdvantageProcessor``, but no longer consumed by the training loop.
- New rank-0 log line at ``__init__`` reports which path is active.
- ``guidance/algorithms.md#dgpo`` gains a "Group Completeness" subsection explaining the path selection and why ``B % K == 0`` unlocks the fast path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
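The reshard step described in this commit message (bucket by ``unique_id``, permute whole groups with a shared seed, take ``groups[rank::world_size]``) can be sketched in plain Python. The function name `reshard_by_group`, the dict-based samples, and the use of `random.Random` are assumptions made for illustration; the real ``_gather_and_reshard()`` is not shown in this thread.

```python
import random
from collections import defaultdict


def reshard_by_group(samples, rank, world_size, seed):
    """Return the samples this rank keeps, as whole groups only."""
    buckets = defaultdict(list)
    for sample in samples:
        buckets[sample["unique_id"]].append(sample)
    # Sort ids first so every rank starts from the same order, then apply
    # the same seeded permutation on every rank.
    groups = [buckets[uid] for uid in sorted(buckets)]
    random.Random(seed).shuffle(groups)
    mine = groups[rank::world_size]  # strided slice: one whole group at a time
    return [sample for group in mine for sample in group]


# 4 prompts (groups) with K = 3 copies each, as if already gathered from all ranks.
gathered = [{"unique_id": uid, "copy": c} for c in range(3) for uid in range(4)]
shards = [reshard_by_group(gathered, rank, world_size=2, seed=0) for rank in range(2)]

for rank, shard in enumerate(shards):
    print(rank, sorted({s["unique_id"] for s in shard}))
```

Because every rank applies the same permutation to the same sorted group list, the strided slices are disjoint and each rank ends up holding only complete groups.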
…-rank RNG

Backward-compatible additions only; no existing call site is touched.

- ``create_generator(*args, device=)``: optional device kwarg creates Generator directly on GPU, avoiding CPU→GPU copy for randn_tensor.
- ``TimeSampler`` methods gain an optional ``generator`` kwarg threaded through logit_normal_shifted / uniform / discrete / _stratified_sample / _raw_logit_normal_unit, plus an ``_rng_device`` helper that routes internal random ops to the generator's device when present.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
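The idea behind keyed generators is that every rank derives the same seed from the same integers, so no broadcast is needed. The sketch below illustrates that idea only: `derive_seed` is a hypothetical helper, the hashing scheme is an assumption, and `random.Random` stands in for a `torch.Generator` so the example runs without PyTorch. How ``create_generator`` actually combines its arguments is not shown in this thread.

```python
import hashlib
import random


def derive_seed(*keys: int) -> int:
    """Map a tuple of integer keys to a stable non-negative 63-bit seed."""
    payload = ",".join(str(int(k)) for k in keys).encode()
    digest = hashlib.sha256(payload).digest()
    return int.from_bytes(digest[:8], "big") >> 1


# Two "ranks" build their generators independently from (seed, epoch, inner_epoch).
rank0 = random.Random(derive_seed(42, 3, 0))
rank1 = random.Random(derive_seed(42, 3, 0))
draws0 = [rank0.random() for _ in range(4)]
draws1 = [rank1.random() for _ in range(4)]

print(draws0 == draws1)
```

In a PyTorch setting the derived integer could be passed to `torch.Generator(device=...).manual_seed(seed)`, which is one way to obtain a generator that lives on the target device.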
…nment split
New sampler and hparams infrastructure for DGPO (no trainer changes yet).
GroupDistributedSampler (data_utils/sampler.py):
- Each rank yields the same prompt-index sequence; each prompt appears
group_size/num_replicas times per rank. Rollout divergence comes from
per-rank generation RNG, not from dataset stripe.
- DGPOTrainer relies on this "rank-identical prompt" contract to derive
cross-rank-consistent group ids via local torch.unique (no collective).
- Geometric constraints (group_size % W == 0, (W*B) % K == 0) are
pre-aligned by Arguments._align_for_group_distributed; sampler __init__
only has defensive asserts.
sampler_loader.py: SAMPLER_REGISTRY dict replaces if/else chain.
data_args.py: sampler_type Literal gains "group_distributed".
args.py — _align_batch_geometry refactored:
- Dispatcher + 3 per-sampler helpers (_align_for_{distributed_k_repeat,
group_contiguous, group_distributed}) sharing _base_unique_sample_step /
_round_up_to_step / _warn_and_assign_unique_sample_num /
_recompute_derived_batch_quantities.
- group_distributed auto-aligns group_size via O(sqrt(B)) divisor search:
finds smallest new_K = W*d where d divides per_device_batch_size and
d >= ceil(K/W).
- _resolve_sampler_type: DGPO forces group_distributed; non-DGPO trainers
unaffected.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
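The divisor search mentioned in this commit message (smallest `new_K = W*d` where `d` divides `per_device_batch_size` and `d >= ceil(K/W)`) can be sketched as follows. The function name `align_group_size` and the error raised when no divisor qualifies are assumptions; the behaviour of ``_align_for_group_distributed`` in that case is not described here. The value `B = 12` in the usage lines is also an assumed batch size.

```python
import math


def align_group_size(group_size: int, world_size: int, per_device_batch_size: int) -> int:
    """Smallest K' = W * d with d | B and d >= ceil(K / W), found in O(sqrt(B))."""
    K, W, B = group_size, world_size, per_device_batch_size
    d_min = -(-K // W)  # integer ceil(K / W)
    best = None
    for i in range(1, math.isqrt(B) + 1):
        if B % i != 0:
            continue
        # Divisors come in pairs (i, B // i); check both.
        for d in (i, B // i):
            if d >= d_min and (best is None or d < best):
                best = d
    if best is None:
        raise ValueError(f"no divisor of B={B} is >= ceil(K/W)={d_min}")
    return W * best


print(align_group_size(24, 4, 12))  # already aligned
print(align_group_size(10, 4, 12))  # rounded up to the next valid size
```

Any result satisfies both constraints named in the message: `new_K % W == 0` by construction, and `(W*B) % new_K == 0` because `B / d` is an integer.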
a3647ba to 754dcd5
Pull request overview
Copilot reviewed 17 out of 17 changed files in this pull request and generated 5 comments.
      - **Selective trajectory recording**: `trajectory_indices` controls which denoising steps are stored. For GRPO, only steps corresponding to `train_timesteps` are kept to reduce memory.
    - - **SDE dynamics for exploration**: GRPO injects noise during sampling via SDE formulation, enabling the log-probability computation required for policy gradients. NFT and AWM use standard ODE solvers.
    + - **SDE dynamics for exploration**: GRPO injects noise during sampling via SDE formulation, enabling the log-probability computation required for policy gradients. NFT, AWM, and DGPO use decoupled sampling (typically ODE) with `compute_log_prob=False`.
      - **Off-policy sampling**: NFT optionally use EMA parameters for sampling (`off_policy: true`), while the current policy is optimized — stabilizing training.
Grammar: “NFT optionally use” should be “NFT optionally uses”.
    - - **Off-policy sampling**: NFT optionally use EMA parameters for sampling (`off_policy: true`), while the current policy is optimized — stabilizing training.
    + - **Off-policy sampling**: NFT optionally uses EMA parameters for sampling (`off_policy: true`), while the current policy is optimized — stabilizing training.
    # ema_ref advances once per optimiser step (reference DGPO);
    # sampling EMA advances once per epoch in ``start()``.
    self._update_ema_ref(step=self.step)
    self.step += 1

    reduced = reduce_loss_info(self.accelerator, loss_info)
    reduced["grad_norm"] = grad_norm
    self.log_data(
        {f"train/{k}": v for k, v in reduced.items()},
        step=self.step,
    )
In _finalize_step, self.step is incremented before logging, so metrics are recorded under the next step. Other trainers (e.g. trainers/awm.py) log at the current self.step and then increment, so DGPO logs will be off-by-one and inconsistent in dashboards. Suggest logging with the current step and incrementing afterwards (while keeping _update_ema_ref keyed to the intended step count).
    - # ema_ref advances once per optimiser step (reference DGPO);
    - # sampling EMA advances once per epoch in ``start()``.
    - self._update_ema_ref(step=self.step)
    - self.step += 1
    - reduced = reduce_loss_info(self.accelerator, loss_info)
    - reduced["grad_norm"] = grad_norm
    - self.log_data(
    -     {f"train/{k}": v for k, v in reduced.items()},
    -     step=self.step,
    - )
    + current_step = self.step
    + # ema_ref advances once per optimiser step (reference DGPO);
    + # sampling EMA advances once per epoch in ``start()``.
    + self._update_ema_ref(step=current_step)
    + reduced = reduce_loss_info(self.accelerator, loss_info)
    + reduced["grad_norm"] = grad_norm
    + self.log_data(
    +     {f"train/{k}": v for k, v in reduced.items()},
    +     step=current_step,
    + )
    + self.step += 1
    def _compute_group_dgpo_loss(
        self,
        model_v: torch.Tensor,
        ref_v: torch.Tensor,
        target_v: torch.Tensor,
        advantages: torch.Tensor,
        group_info: DGPOGroupInfo,
        dsm_loss: torch.Tensor,
    ) -> torch.Tensor:
        """Group-level DGPO loss.

        Under the :class:`GroupDistributedSampler` contract every global
        micro-batch (``num_processes * per_device_batch_size`` samples, seen
        by all ranks in lockstep) holds an integer number of complete groups
        and every rank sees the same ``local_group_indices`` (via
        :meth:`_precompute_group_info`'s local ``torch.unique``). A
        single complete group's ``group_size`` copies are split across
        ranks — one ``group_size / num_processes`` chunk per rank — so we
        ``scatter_add`` the local
        per-sample contributions, ``accelerator.reduce`` across ranks
        to recover the full-group sum, then apply ``sigmoid``. This is
        the only group-level collective in the entire DGPO optimize
        loop.
        """
        device = model_v.device
        num_groups = int(group_info["num_groups"])
        local_group_indices = group_info["local_group_indices"]
_compute_group_dgpo_loss takes a model_v parameter but doesn't use it except to fetch .device. This makes the API misleading and suggests the caller must compute/pass model_v even though only dsm_loss/ref_v/target_v are needed. Consider removing model_v from the signature (or using dsm_loss.device / ref_v.device for local_sums) to reduce coupling and clarify the true inputs to the group-loss.
    dataloader_num_workers: 16
    force_reprocess: false
    cache_dir: "~/.cache/flow_factory/datasets"
    sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous
The inline comment for data.sampler_type is outdated: it lists only auto, distributed_k_repeat, group_contiguous, but this PR also adds group_distributed (and DGPO auto-resolves auto to group_distributed). Updating the comment will prevent confusion for users copying this config.
    - sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous
    + sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous, group_distributed; DGPO resolves auto to group_distributed
    dataloader_num_workers: 16
    force_reprocess: false
    cache_dir: "~/.cache/flow_factory/datasets"
    sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous
The inline comment for data.sampler_type is outdated: it lists only auto, distributed_k_repeat, group_contiguous, but this PR also adds group_distributed (and DGPO auto-resolves auto to group_distributed). Update the comment to match the available options.
    - sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous
    + sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous, group_distributed
DGPO (Direct Group Preference Optimization, ICLR 2026): group-level DPO loss for diffusion model alignment via flow matching.

DGPOTrainer (trainers/dgpo.py):
- Group-level loss: scatter_add per-sample preference terms into per-group sums, accelerator.reduce across ranks, sigmoid reweight. This single reduce is the only group-level collective in the optimize loop.
- Requires GroupDistributedSampler (enforced by hparams); local torch.unique gives cross-rank-consistent group ids without gather.
- On-demand velocity forwards: old_v (ema_ref), ref_v (frozen pretrained), model_v (with grad); only computed when the corresponding feature (clipping / KL / use_ema_ref) is enabled.
- Timestep-invariant shared noise: per-group noise seeded from (seed, epoch, inner_epoch, unique_id), identical for all timesteps within an epoch, matching the reference implementation.
- Optional PPO-style DSM/KL clipping via fast-tracking ema_ref with adaptive decay min(max_decay, ramp_rate * step).
- v-based KL regularisation with optional CFG-enabled frozen reference (kl_cfg > 1).

DGPOTrainingArguments (hparams/training_args.py):
- Extends GRPOTrainingArguments with all DGPO hyper-parameters.
- Registered under 'dgpo' in the hparams registry.

Wiring:
- DGPOTrainer registered under 'dgpo' in trainers/registry.py.
- Example config examples/dgpo/lora/sd3_5.yaml: direct port of the reference pickscore_sd3_4gpu preset (4 GPUs, group_size=24).
- Docs: DGPO section in guidance/algorithms.md covering all config parameters; light touch-ups in AGENTS.md, architecture.md, workflow.md.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
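The adaptive decay named in the commit message, `min(max_decay, ramp_rate * step)`, can be illustrated with a short sketch. It assumes the conventional EMA update `ema = decay * ema + (1 - decay) * param`, which this thread does not spell out, and the numeric values of `max_decay` and `ramp_rate` below are illustrative only. Both function names are hypothetical.

```python
def adaptive_decay(step: int, max_decay: float, ramp_rate: float) -> float:
    """Decay ramps linearly with the step count and is capped at max_decay."""
    return min(max_decay, ramp_rate * step)


def ema_update(ema, params, decay: float):
    """Conventional EMA step (assumed form): keep `decay` of the old value."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]


# Illustrative values only: cap of 0.5, ramp of 0.25 per step.
ema = [0.0]
params = [1.0]
for step in range(4):
    decay = adaptive_decay(step, max_decay=0.5, ramp_rate=0.25)
    ema = ema_update(ema, params, decay)
    print(step, decay, ema)
```

Under this assumed form, a decay of 0 at step 0 makes the reference copy the current parameters outright, and the reference then lags more as the decay ramps toward its cap.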
Resolve conflicts between CRD trainer (this PR) and DGPO trainer (X-GenGroup#133) added to main since branching:
- training_args.py: Place DGPOTrainingArguments before CRDTrainingArguments, each class retains its own fields and defaults
- algorithms.md: Adopt main's reference numbering, add CRD as ref [13]
- crd.py: Add _maybe_offload_samples_to_cpu() call matching the pattern established in GRPO/GRPOGuard on main (PR X-GenGroup#149)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>


Summary
Adds DGPO (Direct Group Preference Optimization, ICLR 2026) and supporting infrastructure to Flow-Factory. Split into 3 commits for clean review:
1. `[utils] feat: optional device/generator kwargs for reproducible cross-rank RNG`: backward-compatible additions to `create_generator` and `TimeSampler`; no existing call site is touched.
2. `[data_utils,hparams] feat: GroupDistributedSampler + per-sampler alignment split`: new sampler, sampler registry, `sampler_type='group_distributed'` literal, and `_align_batch_geometry` refactored into 3 per-sampler helpers with shared primitives.
3. `[trainer,hparams,examples,docs] feat: add DGPO trainer`: the trainer, hparams, registry entry, example config, and documentation (algorithms.md DPO + DGPO sections, README, AGENTS.md, workflow.md).

Algorithm highlights
- … `use_ema_ref=True`).
- … `ema_ref`, distinct from the slow sampling EMA) with adaptive decay `min(ema_ref_max_decay, ema_ref_ramp_rate * step)`.
- … `create_generator` keys; no `dist.broadcast` / RNG fork needed.
- … `(seed, epoch, inner_epoch, uid)`; all training timesteps within an epoch share the same noise, matching the reference implementation.
- `v`-based KL regularisation and CFG-enabled frozen reference (`kl_cfg > 1`).
- `x_t = (1 - σ) x_0 + σ ε` correctly scales scheduler-scale timesteps through `flow_match_sigma(t)` (`σ = t / 1000`).

GroupDistributedSampler
DGPO requires `GroupDistributedSampler` (auto-forced by `Arguments._resolve_sampler_type`). Every rank yields the same prompt-index sequence: each prompt appears `group_size / num_replicas` times per rank. Rollout divergence comes from per-rank generation RNG, not dataset stripe. This means:
- `torch.unique(sorted=True)` produces cross-rank-consistent dense group ids; no `gather_samples` or cross-rank id coordination needed.
- `accelerator.reduce` inside `_compute_group_dgpo_loss` is the only group-level collective in the entire optimize loop.
- `(W × B) % K == 0` is auto-aligned by `_align_for_group_distributed` via O(√B) divisor search on `per_device_batch_size`.

hparams refactor
`_align_batch_geometry` split into 3 per-sampler helpers sharing `_base_unique_sample_step` / `_round_up_to_step` / `_warn_and_assign_unique_sample_num` / `_recompute_derived_batch_quantities`. Each helper is ~10 lines. Non-DGPO samplers (`distributed_k_repeat`, `group_contiguous`) produce identical outputs to the old monolithic function (verified via 4000-case randomized equivalence test).

Wiring
- `DGPOTrainingArguments` (extends `GRPOTrainingArguments`) exports all DGPO hyper-parameters; registered under `'dgpo'` in the hparams registry.
- `DGPOTrainer` registered under `'dgpo'` in `trainers/registry.py`.
- `examples/dgpo/lora/sd3_5.yaml`: direct port of the reference `pickscore_sd3_4gpu` preset (`num_processes: 4`, `group_size = 24`).
- … `guidance/algorithms.md`; DPO and DGPO added to README algorithm table; light touch-ups in AGENTS.md, architecture.md, workflow.md.

Test plan
- … `examples/dgpo/lora/sd3_5.yaml` shows positive reward trajectory.
- … `device=`/`generator=`, so GRPO / DPO / NFT / AWM behavior is bit-identical to `origin/main`.
- `_align_for_distributed_k_repeat` and `_align_for_group_contiguous` produce identical `(M, K, num_batches_per_epoch)` as the original monolithic `_align_batch_geometry`.
- `DGPOTrainingArguments` exported from `flow_factory.hparams`; `_TRAINER_REGISTRY['dgpo']` resolves to `DGPOTrainer`.
- `flow_match_sigma(t_flat)` applied before constructing `x_t`, matching `trainers/nft.py`, `trainers/awm.py`, `trainers/dpo.py`.

🤖 Generated with Claude Code
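The interpolation listed under Algorithm highlights, `x_t = (1 - σ) x_0 + σ ε` with `σ = t / 1000`, can be written out as a small sketch. `sigma_from_timestep` is a hypothetical stand-in for `flow_match_sigma`, whose real signature is not shown here, and plain Python lists stand in for tensors.

```python
def sigma_from_timestep(t: float, num_train_timesteps: float = 1000.0) -> float:
    """Map a scheduler-scale timestep in [0, 1000] to sigma in [0, 1]."""
    return t / num_train_timesteps


def interpolate(x0, eps, t: float):
    """Noisy sample x_t = (1 - sigma) * x_0 + sigma * eps, element-wise."""
    sigma = sigma_from_timestep(t)
    return [(1.0 - sigma) * a + sigma * b for a, b in zip(x0, eps)]


x0 = [2.0]   # clean sample (made-up value)
eps = [4.0]  # noise draw (made-up value)
print(interpolate(x0, eps, 0.0), interpolate(x0, eps, 500.0), interpolate(x0, eps, 1000.0))
```

Passing the raw scheduler timestep (for example 500) in place of sigma would scale `x_0` by a large negative factor, which is why the conversion comes before constructing `x_t`.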