[trainer,hparams,examples,docs] feat: add DGPO (Direct Group Preference Optimization) trainer #133

Merged: Jayce-Ping merged 3 commits into main from DGPO on Apr 24, 2026

Jayce-Ping (Collaborator) commented Apr 24, 2026

Summary

Adds DGPO (Direct Group Preference Optimization, ICLR 2026) and supporting infrastructure to Flow-Factory. Split into 3 commits for clean review:

  1. [utils] feat: optional device/generator kwargs for reproducible cross-rank RNG — backward-compatible additions to create_generator and TimeSampler; no existing call site is touched.
  2. [data_utils,hparams] feat: GroupDistributedSampler + per-sampler alignment split — new sampler, sampler registry, sampler_type='group_distributed' literal, and _align_batch_geometry refactored into 3 per-sampler helpers with shared primitives.
  3. [trainer,hparams,examples,docs] feat: add DGPO trainer — the trainer, hparams, registry entry, example config, and documentation (algorithms.md DPO + DGPO sections, README, AGENTS.md, workflow.md).

Algorithm highlights

  • Group-level DPO loss: scatter-add + sigmoid reweighting on per-group advantage-weighted DSM deltas against a frozen reference (or the fast EMA old policy when use_ema_ref=True).
  • Optional PPO-style clipping on DSM / KL loss, using a fast-tracking EMA (ema_ref, distinct from the slow sampling EMA) with adaptive decay min(ema_ref_max_decay, ema_ref_ramp_rate * step).
  • Cross-rank-deterministic shared timesteps / per-group shared noise via seeded create_generator keys — no dist.broadcast / RNG fork needed.
  • Timestep-invariant shared noise: per-group noise is seeded from (seed, epoch, inner_epoch, uid) — all training timesteps within an epoch share the same noise, matching the reference implementation.
  • v-based KL regularisation and CFG-enabled frozen reference (kl_cfg > 1).
  • Flow-matching interpolation x_t = (1 - σ) x_0 + σ ε correctly scales scheduler-scale timesteps through flow_match_sigma(t) (σ = t / 1000).
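The interpolation and the adaptive EMA decay named in the highlights above can be sketched in plain Python. This is an illustration only: the trainer works on tensors, the `flow_match_sigma` signature is assumed here, and the decay constants in the usage note are hypothetical values, not the PR's defaults.

```python
def flow_match_sigma(t: float, num_train_timesteps: int = 1000) -> float:
    """Map a scheduler-scale timestep in [0, 1000] to sigma in [0, 1]."""
    return t / num_train_timesteps


def interpolate(x0, noise, t):
    """x_t = (1 - sigma) * x_0 + sigma * eps, element-wise on plain lists."""
    sigma = flow_match_sigma(t)
    return [(1.0 - sigma) * a + sigma * b for a, b in zip(x0, noise)]


def ema_ref_decay(step: int, max_decay: float, ramp_rate: float) -> float:
    """Adaptive decay: ramps linearly with the step, capped at max_decay."""
    return min(max_decay, ramp_rate * step)
```

With hypothetical settings `max_decay=0.3` and `ramp_rate=0.001`, the decay is 0 at step 0 (the EMA copy simply tracks the policy) and saturates at 0.3 from step 300 onward.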

GroupDistributedSampler

DGPO requires GroupDistributedSampler (auto-forced by Arguments._resolve_sampler_type). Every rank yields the same prompt-index sequence — each prompt appears group_size / num_replicas times per rank. Rollout divergence comes from per-rank generation RNG, not dataset stripe. This means:

  • Local torch.unique(sorted=True) produces cross-rank-consistent dense group ids — no gather_samples or cross-rank id coordination needed.
  • The single accelerator.reduce inside _compute_group_dgpo_loss is the only group-level collective in the entire optimize loop.
  • Geometric constraint (W × B) % K == 0 is auto-aligned by _align_for_group_distributed via O(√B) divisor search on per_device_batch_size.
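A minimal sketch of the "rank-identical prompt" contract described above, in plain Python. The function names are hypothetical, the ordering of copies within a rank is an assumption, and `dense_group_ids` is a stand-in for the inverse indices of `torch.unique(sorted=True)`; the real sampler may order and shuffle differently.

```python
def rank_prompt_indices(prompt_ids, group_size, num_replicas):
    """Per-rank prompt sequence; it does not depend on the rank, so every rank agrees."""
    if group_size % num_replicas != 0:
        raise ValueError("group_size must be divisible by num_replicas")
    copies_per_rank = group_size // num_replicas
    return [p for p in prompt_ids for _ in range(copies_per_rank)]


def dense_group_ids(unique_ids):
    """Map ids to 0..G-1 by sorted order, computed locally with no communication."""
    lookup = {uid: i for i, uid in enumerate(sorted(set(unique_ids)))}
    return [lookup[uid] for uid in unique_ids]


world_size, group_size = 4, 24
prompts = [17, 3, 42]
per_rank = [rank_prompt_indices(prompts, group_size, world_size) for _ in range(world_size)]
ids_per_rank = [dense_group_ids(seq) for seq in per_rank]
```

Because every rank sees the same set of prompt ids, the sorted-unique mapping yields the same dense ids on each rank without any collective.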

hparams refactor

_align_batch_geometry split into 3 per-sampler helpers sharing _base_unique_sample_step / _round_up_to_step / _warn_and_assign_unique_sample_num / _recompute_derived_batch_quantities. Each helper is ~10 lines. Non-DGPO samplers (distributed_k_repeat, group_contiguous) produce identical outputs to the old monolithic function (verified via 4000-case randomized equivalence test).

Wiring

  • DGPOTrainingArguments (extends GRPOTrainingArguments) exports all DGPO hyper-parameters; registered under 'dgpo' in the hparams registry.
  • DGPOTrainer registered under 'dgpo' in trainers/registry.py.
  • Example config examples/dgpo/lora/sd3_5.yaml — direct port of the reference pickscore_sd3_4gpu preset (num_processes: 4, group_size = 24).
  • Docs: DPO + DGPO sections in guidance/algorithms.md; DPO and DGPO added to README algorithm table; light touch-ups in AGENTS.md, architecture.md, workflow.md.

Test plan

  • Reward trends upward. Training on SD3.5-medium + PickScore with examples/dgpo/lora/sd3_5.yaml shows positive reward trajectory.
  • Sibling trainers unaffected. No existing call site passes device= / generator=, so GRPO / DPO / NFT / AWM behavior is bit-identical to origin/main.
  • hparams equivalence. 4000-case randomized test confirms _align_for_distributed_k_repeat and _align_for_group_contiguous produce identical (M, K, num_batches_per_epoch) as the original monolithic _align_batch_geometry.
  • Registry wiring. DGPOTrainingArguments exported from flow_factory.hparams; _TRAINER_REGISTRY['dgpo'] resolves to DGPOTrainer.
  • σ-scaling contract. flow_match_sigma(t_flat) applied before constructing x_t, matching trainers/nft.py, trainers/awm.py, trainers/dpo.py.
  • Reference implementation audit. Line-by-line comparison with Luo-Yihong/DGPO: all 12 dimensions verified equivalent (σ-scaling, loss, clipping, KL, EMA timing, grad accumulation, forward ordering, group info, shared noise, sampling switch).

🤖 Generated with Claude Code

Copilot AI review requested due to automatic review settings April 24, 2026 03:08

Copilot AI left a comment

Pull request overview

Adds a new DGPO (Direct Group Preference Optimization) trainer to Flow-Factory, plus supporting RNG utilities and documentation/example wiring.

Changes:

  • Extend create_generator and TimeSampler APIs to support optional device/generator inputs for deterministic draws.
  • Add DGPOTrainer, DGPOTrainingArguments, and register them under the dgpo trainer/hparams registries.
  • Add a runnable DGPO example config and update guidance docs to document DGPO and list it in trainer/algorithm tables.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| src/flow_factory/utils/noise_schedule.py | Adds optional generator plumbing to timestep samplers for deterministic/cross-rank draws. |
| src/flow_factory/utils/base.py | Extends create_generator with an optional device argument. |
| src/flow_factory/trainers/registry.py | Registers the new dgpo trainer entry. |
| src/flow_factory/trainers/dgpo.py | Implements the DGPO training loop, EMA-ref logic, shared RNG/noise, and group loss. |
| src/flow_factory/hparams/training_args.py | Introduces DGPOTrainingArguments and registers them under dgpo. |
| src/flow_factory/hparams/__init__.py | Exports DGPOTrainingArguments. |
| guidance/workflow.md | Updates workflow tables/notes to include DGPO behavior. |
| guidance/algorithms.md | Adds a DGPO section with hyperparameter documentation and references. |
| examples/dgpo/lora/sd3_5.yaml | Provides an end-to-end DGPO example config (SD3.5 + PickScore). |
| AGENTS.md | Adds DGPO to the listed algorithms and updates docs pointers. |
| .agents/knowledge/architecture.md | Adds DGPO to the trainer registry table. |

Comment on lines +644 to +660
for inner_epoch in range(self.training_args.num_inner_epochs):
    perm_gen = create_generator(self.training_args.seed, self.epoch, inner_epoch)
    perm = torch.randperm(len(samples), generator=perm_gen)
    shuffled = [samples[i] for i in perm]

    sample_slices = [shuffled[i : i + bsz] for i in range(0, len(shuffled), bsz)]

    # Pre-compute per-minibatch inputs that do not depend on the current
    # policy: stacked batches, shared timesteps, group info. Training
    # noise is generated per micro-step directly in this loop.
    # Keeping a list of dicts (instead of mutating the stacked batch in
    # place) avoids polluting the batch with training-only keys.
    shared_timesteps = self._sample_shared_timesteps(inner_epoch)  # (T,)
    training_batches = self._build_training_batches(
        sample_slices, shared_timesteps, inner_epoch
    )

Copilot AI Apr 24, 2026

DGPO’s group-level weighting assumes each unique_id’s full group contributes to global_sums, but optimize() shuffles/reshards individual samples and then forms sample_slices by per_device_batch_size. With the current samplers (e.g. DistributedKRepeatSampler shuffles all m*k indices each epoch), members of the same group will generally be split across multiple slices/steps, so _compute_group_dgpo_loss() will compute sigmoid weights from partial group sums and change the objective.

To make the group loss correct, batch/iterate in a way that keeps complete groups together (or do a two-pass per timestep: first accumulate per-sample deltas into per-group sums across all samples, then apply the resulting group weights when forming the final DSM loss).
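The difference between per-micro-batch weights and the two-pass alternative suggested above can be seen in a small plain-Python sketch. The per-sample deltas are made-up numbers and the function names are hypothetical; the actual per-sample term used by DGPO is not reproduced here.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def group_sums(deltas, group_ids, num_groups):
    """Scatter-add per-sample deltas into per-group sums."""
    sums = [0.0] * num_groups
    for d, g in zip(deltas, group_ids):
        sums[g] += d
    return sums


def per_microbatch_weights(deltas, group_ids, num_groups, bsz):
    """Weights computed from whatever part of each group lands in a micro-batch."""
    weights = []
    for start in range(0, len(deltas), bsz):
        sl = slice(start, start + bsz)
        sums = group_sums(deltas[sl], group_ids[sl], num_groups)
        weights.extend(sigmoid(sums[g]) for g in group_ids[sl])
    return weights


def two_pass_weights(deltas, group_ids, num_groups):
    """Pass 1 accumulates full-group sums; pass 2 reuses the frozen weights."""
    sums = group_sums(deltas, group_ids, num_groups)
    return [sigmoid(sums[g]) for g in group_ids]
```

For one group with deltas `[1, 1, -1, -1]` and a micro-batch size of 2, the per-micro-batch weights are `sigmoid(2)` for the first two samples and `sigmoid(-2)` for the last two, while the two-pass weights are `sigmoid(0) = 0.5` for all four.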

Reply from the PR author (Collaborator):

Thanks — I audited this against the upstream reference (Luo-Yihong/DGPO: scripts/train_sd3_dgpo.py):

  • Upstream reshapes samples into train_batch_size=12-sized micro-batches at L1048-1056 with K = num_image_per_prompt = 24, i.e. the same K > bsz situation you flagged here.
  • Upstream's compute_group_dgpo_loss_allreduce (L162-193) runs the scatter-add + dist.all_reduce + sigmoid inside each micro-batch, so it too produces "partial-group sigmoid weights" per step.

So the per-micro-batch stochastic group weight is the intended objective, not a corruption — it's the standard SGD approximation of the group preference loss and matches the paper's released code byte-for-byte. The per-sample advantages carried into the sigmoid have already been normalized over the complete group in prepare_feedback, so the stochastic sigmoid reweight stays well-behaved. Validated end-to-end on SD3.5-medium + PickScore — reward trends upward.

Pushed an amend (fee9e36) that expands _compute_group_dgpo_loss's docstring to make the per-micro-batch stochastic semantics explicit and cites the reference implementation lines, so future readers don't have to re-derive it.


Comment on lines +741 to +769
ema_ref_device: str = field(
    default='cuda',
    metadata={"help": "Device for old-policy EMA ref parameters ('cuda' or 'cpu')."},
)

# Timestep control
num_train_timesteps: int = field(
    default=0,
    metadata={"help": "Number of training timesteps per sample. 0 defaults to `int(num_inference_steps * (timestep_range[1] - timestep_range[0]))`."},
)
time_sampling_strategy: Literal['uniform', 'logit_normal', 'discrete', 'discrete_with_init', 'discrete_wo_init'] = field(
    default='discrete',
    metadata={"help": "Strategy for sampling training timesteps."},
)
time_shift: float = field(
    default=3.0,
    metadata={"help": "Shift parameter for logit-normal timestep sampling."},
)
timestep_range: Union[float, Tuple[float, float]] = field(
    default=0.6,
    metadata={"help": "Timestep range for discrete sampling. Float for [0, value], tuple for [start, end]."},
)

def __post_init__(self):
    super().__post_init__()
    self.timestep_range = _standardize_timestep_range(self.timestep_range)
    if not self.num_train_timesteps or self.num_train_timesteps <= 0:
        self.num_train_timesteps = max(1, int(self.num_inference_steps * (self.timestep_range[1] - self.timestep_range[0])))

Copilot AI Apr 24, 2026

ema_ref_device is typed as a free-form str, and DGPOTrainer currently treats anything other than 'cuda' as CPU. This can silently misconfigure runs (e.g. 'cuda:0' or typo) and hurt performance.

Consider validating in DGPOTrainingArguments.__post_init__ (or typing as Literal['cpu','cuda']) and raising a clear error for unsupported values, so misconfiguration fails fast.

Reply from the PR author (Collaborator):

Good catch — every other device field in training_args.py is already Literal["cpu", "cuda"] (ema_device L220, ref_param_device ×4 at L413/465/556/614). ema_ref_device: str was an inconsistency left over from an earlier iteration. Fixed in fee9e36 — now Literal["cpu", "cuda"]. The DGPOTrainer.__init__ consumer only branches on == "cuda" so there's no behavior change, just stronger static typing + CLI arg validation via dataclass.
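For readers who want the fail-fast behaviour Copilot suggested in addition to the `Literal` annotation, a minimal sketch follows. `EmaRefArgs` is a hypothetical stand-in, not the real `DGPOTrainingArguments`, and the explicit check is an optional extra that the PR did not add. A plain dataclass does not enforce `Literal` at runtime; whether the project's argument parser does is not shown in this thread.

```python
from dataclasses import dataclass, field
from typing import Literal, get_args

DeviceName = Literal["cpu", "cuda"]


@dataclass
class EmaRefArgs:  # hypothetical stand-in for the DGPO arguments class
    ema_ref_device: DeviceName = field(
        default="cuda",
        metadata={"help": "Device for old-policy EMA ref parameters ('cuda' or 'cpu')."},
    )

    def __post_init__(self):
        # Explicit runtime check, since the annotation alone is not enforced.
        allowed = get_args(DeviceName)
        if self.ema_ref_device not in allowed:
            raise ValueError(
                f"ema_ref_device must be one of {allowed}, got {self.ema_ref_device!r}"
            )
```

With this check, a value such as `'cuda:0'` raises a `ValueError` at construction time instead of silently falling back to the CPU branch.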

Jayce-Ping (Collaborator, Author) commented Apr 24, 2026

A quick verification of the provided example config:

[Two screenshots: Clipboard_Screenshot_1777004337, Clipboard_Screenshot_1777036884]

Jayce-Ping added a commit that referenced this pull request Apr 24, 2026
…wo-pass

Fix a stochastic-gradient artefact in the DGPO group loss: the previous
implementation's ``sigmoid(group_sums)`` was built from **partial group
sums** because (a) ``DistributedKRepeatSampler`` spreads a group's K
copies across ranks, (b) the in-rank ``randperm`` at the top of
``optimize()`` further splits a group across local micro-batches, and
(c) ``_compute_group_dgpo_loss`` was invoked per micro-batch. Every
optimizer step then saw a different, random subset of each group —
Copilot flagged the symptom on PR #133 L644-660.

The fix replaces the per-minibatch shuffle + per-minibatch sigmoid with
a single, uniform front-end:

1. ``_gather_and_reshard()`` at the start of every ``optimize()`` call
   runs ``gather_samples`` (reusing ``DPOTrainer``'s pattern at
   ``trainers/dpo.py:253``), buckets samples by ``unique_id``, permutes
   whole groups with a cross-rank-consistent seed, and reshards with
   ``groups[rank::world_size]``. After this every rank owns an integer
   number of **complete** groups, regardless of sampler topology.

2. With that invariant, two code paths cover every batch geometry:
   - ``_optimize_exact`` (Path A, ``B % K == 0``): each local micro-batch
     packs ``B/K`` whole groups, so ``_compute_group_dgpo_loss`` computes
     ``sigmoid`` on full-group sums in a single forward. No cross-rank
     reduce needed.
   - ``_optimize_two_pass`` (Path B, otherwise — the common case for
     ``B < K``): outer loop is ``for t_idx``, inner runs pass 1 (no-grad
     forward over every micro-batch, ``scatter_add`` into a rank-wide
     ``global_sums``, freeze ``sigmoid``) then pass 2 (grad forward +
     backward consuming the frozen weights via
     ``_apply_group_weighted_loss``). Pass 1 / pass 2 bit-identical noise
     is guaranteed by force-enabling ``use_shared_noise=True`` in this
     path.

Side effects:
- ``_precompute_group_info`` loses its ``distributed_k_repeat`` gather
  branch (resharded samples are always rank-local for group purposes).
- ``_compute_group_dgpo_loss`` loses the ``_needs_cross_rank_reduce``
  dispatch — the sum is always local after reshard.
- ``_group_on_same_rank`` property retained for parity with
  ``AdvantageProcessor``, but no longer consumed by the training loop.
- New rank-0 log line at ``__init__`` reports which path is active.
- ``guidance/algorithms.md#dgpo`` gains a "Group Completeness" subsection
  explaining the path selection and why ``B % K == 0`` unlocks the fast
  path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
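The gather-and-reshard step in the commit message above (bucket by `unique_id`, permute whole groups with a shared seed, take `groups[rank::world_size]`) can be sketched as follows. This is plain Python under stated assumptions: `all_samples` stands for the already-gathered list that is identical on every rank, the sample dict layout is hypothetical, and `random.Random` replaces whatever seeded generator the trainer uses.

```python
import random
from collections import defaultdict


def gather_and_reshard(all_samples, rank, world_size, seed):
    """Return the samples of the complete groups owned by `rank`."""
    groups = defaultdict(list)
    for sample in all_samples:
        groups[sample["unique_id"]].append(sample)

    # Start from a deterministic order, then shuffle with a seed shared by
    # all ranks so every rank derives the same group permutation.
    ordered = [groups[uid] for uid in sorted(groups)]
    random.Random(seed).shuffle(ordered)

    local_groups = ordered[rank::world_size]
    return [sample for group in local_groups for sample in group]
```

Because whole groups are assigned, no group is ever split across ranks. Ranks receive equal shares only when the number of groups is divisible by `world_size`.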
Jayce-Ping and others added 2 commits April 24, 2026 18:33
…-rank RNG

Backward-compatible additions only; no existing call site is touched.

- ``create_generator(*args, device=)`` — optional device kwarg creates
  Generator directly on GPU, avoiding CPU→GPU copy for randn_tensor.
- ``TimeSampler`` methods gain an optional ``generator`` kwarg threaded
  through logit_normal_shifted / uniform / discrete / _stratified_sample /
  _raw_logit_normal_unit, plus an ``_rng_device`` helper that routes
  internal random ops to the generator's device when present.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
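The idea behind keyed generators, where every rank derives the same random stream from the same key without any broadcast, can be illustrated with the standard library. The helper names and the hashing scheme below are assumptions for illustration; how `create_generator` actually maps its arguments to a seed is not shown in this PR, and it returns a torch `Generator` rather than a `random.Random`.

```python
import hashlib
import random


def make_rng(*key):
    """Derive a deterministic RNG from an arbitrary key tuple (hypothetical scheme)."""
    digest = hashlib.sha256(repr(key).encode("utf-8")).digest()
    return random.Random(int.from_bytes(digest[:8], "big"))


def group_noise(seed, epoch, inner_epoch, uid, dim):
    """Per-group noise that depends only on the key, not on rank or timestep."""
    rng = make_rng(seed, epoch, inner_epoch, uid)
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]
```

Two ranks calling `group_noise(42, 0, 0, uid, dim)` with the same `uid` obtain identical noise, which is the property the shared-noise feature relies on.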
…nment split

New sampler and hparams infrastructure for DGPO (no trainer changes yet).

GroupDistributedSampler (data_utils/sampler.py):
- Each rank yields the same prompt-index sequence; each prompt appears
  group_size/num_replicas times per rank. Rollout divergence comes from
  per-rank generation RNG, not from dataset stripe.
- DGPOTrainer relies on this "rank-identical prompt" contract to derive
  cross-rank-consistent group ids via local torch.unique (no collective).
- Geometric constraints (group_size % W == 0, (W*B) % K == 0) are
  pre-aligned by Arguments._align_for_group_distributed; sampler __init__
  only has defensive asserts.

sampler_loader.py: SAMPLER_REGISTRY dict replaces if/else chain.

data_args.py: sampler_type Literal gains "group_distributed".

args.py — _align_batch_geometry refactored:
- Dispatcher + 3 per-sampler helpers (_align_for_{distributed_k_repeat,
  group_contiguous, group_distributed}) sharing _base_unique_sample_step /
  _round_up_to_step / _warn_and_assign_unique_sample_num /
  _recompute_derived_batch_quantities.
- group_distributed auto-aligns group_size via O(sqrt(B)) divisor search:
  finds smallest new_K = W*d where d divides per_device_batch_size and
  d >= ceil(K/W).
- _resolve_sampler_type: DGPO forces group_distributed; non-DGPO trainers
  unaffected.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
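The divisor search described in the commit message above (smallest `new_K = W*d` with `d` dividing `per_device_batch_size` and `d >= ceil(K/W)`) can be sketched in a few lines. The function name is hypothetical, and raising an error when no divisor qualifies is an assumption; the commit message does not say what the real helper does in that case.

```python
import math


def align_group_size(group_size, world_size, per_device_batch_size):
    """Return the smallest aligned group size W * d, enumerating divisors in O(sqrt(B))."""
    min_d = -(-group_size // world_size)  # integer ceil(K / W)
    divisors = set()
    for i in range(1, math.isqrt(per_device_batch_size) + 1):
        if per_device_batch_size % i == 0:
            divisors.add(i)
            divisors.add(per_device_batch_size // i)
    candidates = [d for d in divisors if d >= min_d]
    if not candidates:
        # Assumed behaviour for illustration only.
        raise ValueError("no divisor of per_device_batch_size is large enough")
    return world_size * min(candidates)
```

Any result satisfies both constraints: `new_K % W == 0` by construction, and `(W * B) % new_K == 0` because `d` divides `B`. For example, `K=10, W=4, B=12` gives `d=3` and an aligned group size of 12.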

Copilot AI left a comment

Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 5 comments.


Comment thread guidance/workflow.md Outdated
- **Selective trajectory recording**: `trajectory_indices` controls which denoising steps are stored. For GRPO, only steps corresponding to `train_timesteps` are kept to reduce memory.
- **SDE dynamics for exploration**: GRPO injects noise during sampling via SDE formulation, enabling the log-probability computation required for policy gradients. NFT and AWM use standard ODE solvers.
- **SDE dynamics for exploration**: GRPO injects noise during sampling via SDE formulation, enabling the log-probability computation required for policy gradients. NFT, AWM, and DGPO use decoupled sampling (typically ODE) with `compute_log_prob=False`.
- **Off-policy sampling**: NFT optionally use EMA parameters for sampling (`off_policy: true`), while the current policy is optimized — stabilizing training.

Copilot AI Apr 24, 2026

Grammar: “NFT optionally use” should be “NFT optionally uses”.

Suggested change
- - **Off-policy sampling**: NFT optionally use EMA parameters for sampling (`off_policy: true`), while the current policy is optimized — stabilizing training.
+ - **Off-policy sampling**: NFT optionally uses EMA parameters for sampling (`off_policy: true`), while the current policy is optimized — stabilizing training.

Comment on lines +1014 to +1024
# ema_ref advances once per optimiser step (reference DGPO);
# sampling EMA advances once per epoch in ``start()``.
self._update_ema_ref(step=self.step)
self.step += 1

reduced = reduce_loss_info(self.accelerator, loss_info)
reduced["grad_norm"] = grad_norm
self.log_data(
    {f"train/{k}": v for k, v in reduced.items()},
    step=self.step,
)

Copilot AI Apr 24, 2026

In _finalize_step, self.step is incremented before logging, so metrics are recorded under the next step. Other trainers (e.g. trainers/awm.py) log at the current self.step and then increment, so DGPO logs will be off-by-one and inconsistent in dashboards. Suggest logging with the current step and incrementing afterwards (while keeping _update_ema_ref keyed to the intended step count).

Suggested change
- # ema_ref advances once per optimiser step (reference DGPO);
- # sampling EMA advances once per epoch in ``start()``.
- self._update_ema_ref(step=self.step)
- self.step += 1
- reduced = reduce_loss_info(self.accelerator, loss_info)
- reduced["grad_norm"] = grad_norm
- self.log_data(
-     {f"train/{k}": v for k, v in reduced.items()},
-     step=self.step,
- )
+ current_step = self.step
+ # ema_ref advances once per optimiser step (reference DGPO);
+ # sampling EMA advances once per epoch in ``start()``.
+ self._update_ema_ref(step=current_step)
+ reduced = reduce_loss_info(self.accelerator, loss_info)
+ reduced["grad_norm"] = grad_norm
+ self.log_data(
+     {f"train/{k}": v for k, v in reduced.items()},
+     step=current_step,
+ )
+ self.step += 1

Comment on lines +490 to +516
def _compute_group_dgpo_loss(
    self,
    model_v: torch.Tensor,
    ref_v: torch.Tensor,
    target_v: torch.Tensor,
    advantages: torch.Tensor,
    group_info: DGPOGroupInfo,
    dsm_loss: torch.Tensor,
) -> torch.Tensor:
    """Group-level DGPO loss.

    Under the :class:`GroupDistributedSampler` contract every global
    micro-batch (``num_processes * per_device_batch_size`` samples, seen
    by all ranks in lockstep) holds an integer number of complete groups
    and every rank sees the same ``local_group_indices`` (via
    :meth:`_precompute_group_info`'s local ``torch.unique``). A
    single complete group's ``group_size`` copies are split across
    ranks, one ``group_size / num_processes`` chunk per rank, so we
    ``scatter_add`` the local per-sample contributions,
    ``accelerator.reduce`` across ranks to recover the full-group sum,
    then apply ``sigmoid``. This is the only group-level collective in
    the entire DGPO optimize loop.
    """
    device = model_v.device
    num_groups = int(group_info["num_groups"])
    local_group_indices = group_info["local_group_indices"]

Copilot AI Apr 24, 2026

_compute_group_dgpo_loss takes a model_v parameter but doesn't use it except to fetch .device. This makes the API misleading and suggests the caller must compute/pass model_v even though only dsm_loss/ref_v/target_v are needed. Consider removing model_v from the signature (or using dsm_loss.device / ref_v.device for local_sums) to reduce coupling and clarify the true inputs to the group-loss.

Comment thread examples/dgpo/lora/sd3_5.yaml Outdated
dataloader_num_workers: 16
force_reprocess: false
cache_dir: "~/.cache/flow_factory/datasets"
sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous

Copilot AI Apr 24, 2026

The inline comment for data.sampler_type is outdated: it lists only auto, distributed_k_repeat, group_contiguous, but this PR also adds group_distributed (and DGPO auto-resolves auto to group_distributed). Updating the comment will prevent confusion for users copying this config.

Suggested change
- sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous
+ sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous, group_distributed; DGPO resolves auto to group_distributed

Comment thread examples/dgpo/lora/sd3_5_nocfg.yaml Outdated
dataloader_num_workers: 16
force_reprocess: false
cache_dir: "~/.cache/flow_factory/datasets"
sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous

Copilot AI Apr 24, 2026

The inline comment for data.sampler_type is outdated: it lists only auto, distributed_k_repeat, group_contiguous, but this PR also adds group_distributed (and DGPO auto-resolves auto to group_distributed). Update the comment to match the available options.

Suggested change
- sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous
+ sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous, group_distributed

DGPO (Direct Group Preference Optimization, ICLR 2026) — group-level DPO
loss for diffusion model alignment via flow matching.

DGPOTrainer (trainers/dgpo.py):
- Group-level loss: scatter_add per-sample preference terms into per-group
  sums, accelerator.reduce across ranks, sigmoid reweight. This single
  reduce is the only group-level collective in the optimize loop.
- Requires GroupDistributedSampler (enforced by hparams); local
  torch.unique gives cross-rank-consistent group ids without gather.
- On-demand velocity forwards: old_v (ema_ref), ref_v (frozen pretrained),
  model_v (with grad) — only computed when the corresponding feature
  (clipping / KL / use_ema_ref) is enabled.
- Timestep-invariant shared noise: per-group noise seeded from
  (seed, epoch, inner_epoch, unique_id), identical for all timesteps
  within an epoch, matching the reference implementation.
- Optional PPO-style DSM/KL clipping via fast-tracking ema_ref with
  adaptive decay min(max_decay, ramp_rate * step).
- v-based KL regularisation with optional CFG-enabled frozen reference
  (kl_cfg > 1).

DGPOTrainingArguments (hparams/training_args.py):
- Extends GRPOTrainingArguments with all DGPO hyper-parameters.
- Registered under 'dgpo' in the hparams registry.

Wiring:
- DGPOTrainer registered under 'dgpo' in trainers/registry.py.
- Example config examples/dgpo/lora/sd3_5.yaml — direct port of the
  reference pickscore_sd3_4gpu preset (4 GPUs, group_size=24).
- Docs: DGPO section in guidance/algorithms.md covering all config
  parameters; light touch-ups in AGENTS.md, architecture.md, workflow.md.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
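The scatter-add, cross-rank reduce, and sigmoid sequence listed in the commit message above can be simulated without a distributed runtime. In this sketch the ranks are plain lists, `all_reduce_sum` is a stand-in for `accelerator.reduce`, and the per-sample contributions are made-up numbers; the exact preference term and any scaling used by DGPO are not reproduced.

```python
import math


def local_scatter_add(contribs, group_ids, num_groups):
    """Accumulate this rank's per-sample contributions into per-group sums."""
    sums = [0.0] * num_groups
    for c, g in zip(contribs, group_ids):
        sums[g] += c
    return sums


def all_reduce_sum(per_rank_sums):
    """Element-wise sum across ranks (stand-in for the single collective)."""
    return [sum(values) for values in zip(*per_rank_sums)]


def group_weights(per_rank_contribs, per_rank_group_ids, num_groups):
    """Full-group sums recovered from all ranks, then squashed by a sigmoid."""
    per_rank_sums = [
        local_scatter_add(c, g, num_groups)
        for c, g in zip(per_rank_contribs, per_rank_group_ids)
    ]
    global_sums = all_reduce_sum(per_rank_sums)
    return [1.0 / (1.0 + math.exp(-s)) for s in global_sums]
```

The group ids must mean the same thing on every rank for the element-wise sum to be valid, which is what the rank-identical prompt contract provides.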
@Jayce-Ping Jayce-Ping merged commit 55a5268 into main Apr 24, 2026
Jayce-Ping added a commit to yuanzhi-zhu/Flow-Factory that referenced this pull request May 2, 2026
Resolve conflicts between CRD trainer (this PR) and DGPO trainer (X-GenGroup#133)
added to main since branching:

- training_args.py: Place DGPOTrainingArguments before CRDTrainingArguments,
  each class retains its own fields and defaults
- algorithms.md: Adopt main's reference numbering, add CRD as ref [13]
- crd.py: Add _maybe_offload_samples_to_cpu() call matching the pattern
  established in GRPO/GRPOGuard on main (PR X-GenGroup#149)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Jayce-Ping added a commit to Jayce-Ping/Flow-Factory-Private that referenced this pull request Jul 2, 2026
…ce Optimization) trainer (X-GenGroup#133)
