From 87c8c0c7a8b6be66e87911877946c6e2f07b5eab Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Fri, 25 Jul 2025 15:30:52 -0400 Subject: [PATCH 1/3] update flux image transform to resize properly --- .../data/transforms/flux_image_transform.py | 36 +++++++++++++++---- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/src/invoke_training/_shared/data/transforms/flux_image_transform.py b/src/invoke_training/_shared/data/transforms/flux_image_transform.py index b840f43a..669f4b28 100644 --- a/src/invoke_training/_shared/data/transforms/flux_image_transform.py +++ b/src/invoke_training/_shared/data/transforms/flux_image_transform.py @@ -1,6 +1,7 @@ import typing from torchvision import transforms +from torchvision.transforms.functional import crop from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager, Resolution from invoke_training._shared.data.utils.resize import resize_to_cover @@ -43,32 +44,53 @@ def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typin for field_name in self.image_field_names: image_fields[field_name] = data[field_name] + # Get the first image to determine original size and resolution + first_image = next(iter(image_fields.values())) + original_size_hw = (first_image.height, first_image.width) + for field_name, image in image_fields.items(): # Determine the target image resolution. if self.resolution is not None: resolution = self.resolution resolution_obj = Resolution(resolution, resolution) else: - original_size_hw = (image.height, image.width) resolution_obj = self.aspect_ratio_bucket_manager.get_aspect_ratio_bucket( Resolution.parse(original_size_hw) ) image = resize_to_cover(image, resolution_obj) + + # Apply cropping and record top left crop position if self.center_crop: - image = transforms.CenterCrop(resolution)(image) + top_left_y = max(0, (image.height - resolution_obj.height) // 2) + top_left_x = max(0, (image.width - resolution_obj.width) // 2) + image = transforms.CenterCrop(resolution_obj.to_tuple())(image) else: - image = transforms.RandomCrop(resolution)(image) - - image = transforms.ToTensor()(image) + crop_transform = transforms.RandomCrop(resolution_obj.to_tuple()) + top_left_y, top_left_x, h, w = crop_transform.get_params(image, resolution_obj.to_tuple()) + image = crop(image, top_left_y, top_left_x, resolution_obj.height, resolution_obj.width) + # Apply random flip and update top left crop position accordingly if self.random_flip: - image = transforms.RandomHorizontalFlip(p=0.5)(image) - image_fields[field_name] = image + # TODO: Use a seed for repeatable results + import random + if random.random() < 0.5: + top_left_x = original_size_hw[1] - image.width - top_left_x + image = transforms.RandomHorizontalFlip(p=1.0)(image) + + image = transforms.ToTensor()(image) if field_name in self.fields_to_normalize_to_range_minus_one_to_one: image_fields[field_name] = transforms.Normalize([0.5], [0.5])(image) + else: + image_fields[field_name] = image + # Store the processed images and metadata for field_name, image in image_fields.items(): data[field_name] = image + + # Add metadata fields expected by VAE caching + data["original_size_hw"] = original_size_hw + data["crop_top_left_yx"] = (top_left_y, top_left_x) + return data From c5d19344f83af0e163622ca1997eb8f893e10db5 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Mon, 28 Jul 2025 14:00:32 -0400 Subject: [PATCH 2/3] ruff fix --- .../_shared/data/transforms/flux_image_transform.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/invoke_training/_shared/data/transforms/flux_image_transform.py b/src/invoke_training/_shared/data/transforms/flux_image_transform.py index 669f4b28..8a77e2f4 100644 --- a/src/invoke_training/_shared/data/transforms/flux_image_transform.py +++ b/src/invoke_training/_shared/data/transforms/flux_image_transform.py @@ -59,7 +59,7 @@ def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typin ) image = resize_to_cover(image, resolution_obj) - + # Apply cropping and record top left crop position if self.center_crop: top_left_y = max(0, (image.height - resolution_obj.height) // 2) @@ -88,9 +88,9 @@ def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typin # Store the processed images and metadata for field_name, image in image_fields.items(): data[field_name] = image - + # Add metadata fields expected by VAE caching data["original_size_hw"] = original_size_hw data["crop_top_left_yx"] = (top_left_y, top_left_x) - + return data From 01217077c05e007a1d48a1dbbf56864cbef025e6 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Mon, 28 Jul 2025 14:36:53 -0400 Subject: [PATCH 3/3] ruff format --- .../_shared/data/transforms/flux_image_transform.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/invoke_training/_shared/data/transforms/flux_image_transform.py b/src/invoke_training/_shared/data/transforms/flux_image_transform.py index 8a77e2f4..0c945895 100644 --- a/src/invoke_training/_shared/data/transforms/flux_image_transform.py +++ b/src/invoke_training/_shared/data/transforms/flux_image_transform.py @@ -74,6 +74,7 @@ def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typin if self.random_flip: # TODO: Use a seed for repeatable results import random + if random.random() < 0.5: top_left_x = original_size_hw[1] - image.width - top_left_x image = transforms.RandomHorizontalFlip(p=1.0)(image)