SimpleTuner
SimpleTuner copied to clipboard
crop_aspect_buckets needs both a portrait and a landscape bucket
For _select_random_aspect() to return anything other than bucket 1.0, we need to decide what to do when only landscape or only portrait buckets are specified. I'm not sure what the best way to tackle this is.
I've tested this and it seems to work, but I don't know the full extent of this problem or what else needs to be done.
def _select_random_aspect(self):
    """
    Pick the aspect ratio bucket to use for this image.

    When ``self.valid_metadata`` is truthy, the aspect ratio stored in
    ``self.image_metadata`` is written to ``self.aspect_ratio`` and returned
    without any random draw.

    Otherwise the format of ``self.crop_aspect_buckets`` is decided by its
    first element:

    * dict entries (``{'aspect': ..., 'weight': ...}``): a weighted random
      draw over every configured bucket. Portrait-only or landscape-only
      configurations are drawn from as-is; there is no square fallback.
    * float entries: a uniform random draw from the list returned by
      ``self._trim_aspect_bucket_list()``, or 1.0 (with a warning) when that
      list is empty.

    Returns:
        float: The selected aspect ratio.

    Raises:
        ValueError: If no buckets are configured, if the dictionary weights
            do not sum to 1 (within floating-point tolerance), or if the
            buckets are neither floats nor dictionaries.
    """
    # Imported locally so the method does not depend on a module-level import.
    import math

    if not self.crop_aspect_buckets:
        raise ValueError("Aspect buckets are not defined in the data backend config.")

    if self.valid_metadata:
        self.aspect_ratio = self.image_metadata["aspect_ratio"]
        return self.aspect_ratio

    # The list is known to be non-empty here; its first entry decides the format.
    first_bucket = self.crop_aspect_buckets[0]

    if isinstance(first_bucket, dict):
        # Use whatever buckets are configured instead of defaulting to 1.0.
        aspects = [bucket["aspect"] for bucket in self.crop_aspect_buckets]
        weights = [bucket["weight"] for bucket in self.crop_aspect_buckets]
        # Compare with a tolerance: an exact `!= 1.0` rejects valid inputs such
        # as ten weights of 0.1, whose float sum is 0.9999999999999999.
        if not math.isclose(sum(weights), 1.0, abs_tol=1e-9):
            raise ValueError("The weights of aspect buckets must add up to 1.")
        return random.choices(aspects, weights)[0]

    if isinstance(first_bucket, float):
        available_aspects = self._trim_aspect_bucket_list()
        if len(available_aspects) == 0:
            if should_log():
                tqdm.write("[WARNING] Image dimensions do not fit into the configured aspect buckets. Using square crop.")
            return 1.0
        return random.choice(available_aspects)

    raise ValueError(
        "Aspect buckets must be a list of floats or dictionaries."
        " If using a dictionary, it is expected to be in the format {'aspect': 1.0, 'weight': 0.5}."
        " To provide multiple aspect ratios, use a list of dictionaries: [{'aspect': 1.0, 'weight': 0.5}, {'aspect': 1.5, 'weight': 0.5}]."
    )
One might want to pick the closest bucket instead of a random one, too.
Something like this. I had to update a few places to support "closest".
def _select_random_aspect(self):
    """
    Pick an aspect ratio bucket according to ``self.crop_aspect``.

    When ``self.valid_metadata`` is truthy, the aspect ratio stored in
    ``self.image_metadata`` is written to ``self.aspect_ratio`` and returned
    before ``crop_aspect`` is consulted.

    Otherwise:

    * ``crop_aspect == "closest"``: returns the configured bucket whose aspect
      is nearest to ``self.aspect_ratio``. Buckets may be floats or
      ``{'aspect': ..., 'weight': ...}`` dictionaries; weights are ignored.
    * ``crop_aspect == "random"``: dictionary buckets give a weighted random
      draw over every configured bucket (portrait-only or landscape-only
      configurations are used as-is). Float buckets give a uniform draw from
      ``self._trim_aspect_bucket_list()``, or 1.0 with a warning when that
      list is empty.
    * any other ``crop_aspect`` value: returns 1.0.

    Returns:
        float: The selected aspect ratio.

    Raises:
        ValueError: If no buckets are configured, or, in "random" mode, if the
            dictionary weights do not sum to 1 (within floating-point
            tolerance) or the buckets are neither floats nor dictionaries.
    """
    # Imported locally so the method does not depend on a module-level import.
    import math

    if not self.crop_aspect_buckets:
        raise ValueError("Aspect buckets are not defined in the data backend config.")

    if self.valid_metadata:
        self.aspect_ratio = self.image_metadata["aspect_ratio"]
        return self.aspect_ratio

    def _aspect_of(bucket):
        # Buckets are either bare aspect values or {'aspect', 'weight'} dicts.
        return bucket["aspect"] if isinstance(bucket, dict) else bucket

    # "closest" mode: pick the bucket nearest to the image's own aspect ratio.
    if self.crop_aspect == "closest":
        closest_aspect_value = _aspect_of(
            min(
                self.crop_aspect_buckets,
                key=lambda bucket: abs(_aspect_of(bucket) - self.aspect_ratio),
            )
        )
        logger.info(
            "Selected closest aspect: %s for aspect ratio: %s",
            closest_aspect_value,
            self.aspect_ratio,
        )
        return closest_aspect_value

    # "random" mode: weighted draw for dict buckets, uniform draw for floats.
    if self.crop_aspect == "random":
        # The list is known to be non-empty here; its first entry decides the format.
        first_bucket = self.crop_aspect_buckets[0]

        if isinstance(first_bucket, dict):
            # Use whatever buckets are configured instead of defaulting to 1.0.
            aspects = [bucket["aspect"] for bucket in self.crop_aspect_buckets]
            weights = [bucket["weight"] for bucket in self.crop_aspect_buckets]
            # Compare with a tolerance: an exact `!= 1.0` rejects valid inputs
            # such as ten weights of 0.1 (float sum 0.9999999999999999).
            if not math.isclose(sum(weights), 1.0, abs_tol=1e-9):
                raise ValueError("The weights of aspect buckets must add up to 1.")
            return random.choices(aspects, weights)[0]

        if isinstance(first_bucket, float):
            available_aspects = self._trim_aspect_bucket_list()
            if len(available_aspects) == 0:
                if should_log():
                    tqdm.write("[WARNING] Image dimensions do not fit into the configured aspect buckets. Using square crop.")
                return 1.0
            return random.choice(available_aspects)

        raise ValueError(
            "Aspect buckets must be a list of floats or dictionaries."
            " If using a dictionary, it is expected to be in the format {'aspect': 1.0, 'weight': 0.5}."
            " To provide multiple aspect ratios, use a list of dictionaries: [{'aspect': 1.0, 'weight': 0.5}, {'aspect': 1.5, 'weight': 0.5}]."
        )

    # Any other crop_aspect mode falls back to a square aspect.
    return 1.0