Skip to content

Commit f36fa63

Browse files
committed
Fix: SES
1 parent 1d07b66 commit f36fa63

6 files changed

Lines changed: 15 additions & 23 deletions

File tree

diffsynth/utils/inference_time_scaling/ses.py

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -68,10 +68,9 @@ def _load_model(self):
6868

6969
def get_score(self, image_pil, text_prompt):
7070
try:
71-
with torch.no_grad():
71+
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float32):
7272
if self.reward_name == "pick":
7373
inputs = self.processor(text=[text_prompt], images=[image_pil], return_tensors="pt", padding="max_length", truncation=True, max_length=77).to(self.device)
74-
inputs['pixel_values'] = inputs['pixel_values'].to(self.dtype)
7574
outputs = self.model(**inputs)
7675
return outputs.logits_per_image[0, 0].item()
7776

@@ -99,14 +98,13 @@ def run_ses_cem(
9998
popsize=10,
10099
k_elites=5,
101100
wavelet_name="db1",
102-
dwt_level=5,
103-
lambda_prior=1e-3
101+
dwt_level=4,
104102
):
105103
latent_h, latent_w = base_latents.shape[-2], base_latents.shape[-1]
106104
c_low_init, c_high_fixed_batch = split_dwt(base_latents, wavelet_name, dwt_level)
107105
c_high_fixed = c_high_fixed_batch[0]
108-
c_low_shape = c_low_init.shape[1:]
109-
mu = c_low_init.view(-1).cpu()
106+
c_low_shape = c_low_init.shape[1:]
107+
mu = torch.zeros_like(c_low_init.view(-1).cpu())
110108
sigma_sq = torch.ones_like(mu) * 1.0
111109

112110
best_overall = {"fitness": -float('inf'), "score": -float('inf'), "c_low": c_low_init[0]}
@@ -135,16 +133,12 @@ def run_ses_cem(
135133
img = pipeline_callback(z_recon)
136134

137135
score = scorer.get_score(img, prompt)
138-
penalty = lambda_prior * (torch.norm(c_low_sample.float())**2).item()
139-
fitness = score - penalty
140-
141136
res = {
142-
"fitness": fitness,
143137
"score": score,
144138
"c_low": c_low_sample.cpu()
145139
}
146140
batch_results.append(res)
147-
if fitness > best_overall['fitness']:
141+
if score > best_overall['score']:
148142
best_overall = res
149143

150144
eval_count += 1
@@ -156,7 +150,7 @@ def run_ses_cem(
156150

157151
if not batch_results: break
158152
elite_db.extend(batch_results)
159-
elite_db.sort(key=lambda x: x['fitness'], reverse=True)
153+
elite_db.sort(key=lambda x: x['score'], reverse=True)
160154
elite_db = elite_db[:k_elites]
161155
elites_flat = torch.stack([x['c_low'].view(-1) for x in elite_db])
162156
mu_new = torch.mean(elites_flat, dim=0)

examples/flux/model_inference/FLUX.1-dev-SES.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
],
1313
)
1414

15-
prompt = "A solo girl with silver wavy hair and blue eyes, wearing a blue dress, underwater, air bubbles, floating hair."
15+
prompt = "A magical forest where trees are made of candy"
1616
negative_prompt = "nsfw, low quality"
1717

1818
image = pipe(

examples/flux2/model_inference/FLUX.2-dev-SES.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
],
2222
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
2323
)
24-
prompt = "A hermit crab using a soda can as its shell on the beach. The can has the text 'BFL Diffusers' on it."
24+
prompt = "A magical forest where trees are made of candy"
2525

2626
image = pipe(
2727
prompt,
@@ -31,7 +31,7 @@
3131
enable_ses=True,
3232
ses_reward_model="pick",
3333
ses_eval_budget=20,
34-
ses_inference_steps=10
34+
ses_inference_steps=20
3535
)
3636

3737
image.save("image_FLUX.2-dev_ses.jpg")

examples/qwen_image/model_inference/Qwen-Image-SES.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,15 +12,15 @@
1212
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
1313
)
1414

15-
prompt = "水下少女,身穿蓝裙,周围有气泡。"
15+
prompt = "一把精致的汉服折扇,上面绘有山水"
1616

1717
image = pipe(
1818
prompt,
1919
seed=0,
2020
num_inference_steps=40,
2121
enable_ses=True,
2222
ses_reward_model="pick",
23-
ses_eval_budget=20,
23+
ses_eval_budget=30,
2424
ses_inference_steps=10
2525
)
2626

examples/z_image/model_inference/Z-Image-SES.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
],
1313
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
1414
)
15-
prompt = "Chinese woman in red Hanfu holding a fan, with a bright yellow neon lightning bolt floating above her palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
15+
prompt = "A cyberpunk girl with neon glowing eyes"
1616

1717
image = pipe(
1818
prompt=prompt,
@@ -25,6 +25,4 @@
2525
ses_eval_budget=20,
2626
ses_inference_steps=10
2727
)
28-
image.save("image_Z-Image_ses.jpg")
29-
30-
28+
image.save("image_Z-Image_ses.jpg")

examples/z_image/model_inference/Z-Image-Turbo-SES.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,15 +12,15 @@
1212
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
1313
)
1414

15-
prompt = "Chinese woman in red Hanfu holding a fan, with a bright yellow neon lightning bolt floating above her palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
15+
prompt = "A cyberpunk girl with neon glowing eyes"
1616

1717
image = pipe(
1818
prompt=prompt,
1919
seed=42,
2020
rand_device="cuda",
2121
enable_ses=True,
2222
ses_reward_model="pick",
23-
ses_eval_budget=50,
23+
ses_eval_budget=30,
2424
ses_inference_steps=8
2525
)
2626
image.save("image_Z-Image-Turbo_ses.jpg")

0 commit comments

Comments (0)