SimpleTuner icon indicating copy to clipboard operation
SimpleTuner copied to clipboard

crop_aspect_buckets needs both a portrait and a landscape bucket

Open twri opened this issue 1 year ago • 2 comments

For _select_random_aspect() to return anything other than bucket 1.0, we need to decide what to do when only landscape buckets or only portrait buckets are specified. I'm not sure what the best way to tackle this is.

I've tested this and it seems to work, but I don't know the full extent of this problem or what needs to be done.

    def _select_random_aspect(self):
        """
        Return an aspect ratio to crop the current image to.

        If the image already has valid cached metadata, its stored aspect ratio is
        reused (and written back to ``self.aspect_ratio``). Otherwise a bucket is
        drawn at random from ``self.crop_aspect_buckets``, which may be either:

        * a list of numbers, e.g. ``[0.75, 1.0, 1.5]``: candidates are narrowed by
          ``self._trim_aspect_bucket_list()`` and one is picked uniformly; if none
          remain, a square (1.0) crop is used; or
        * a list of dicts, e.g. ``[{'aspect': 1.0, 'weight': 0.5}, ...]``: one
          bucket is picked with probability proportional to its weight. All
          configured buckets are candidates, so a portrait-only or landscape-only
          list still yields non-square crops.

        Returns:
            float: The selected aspect ratio.

        Raises:
            ValueError: If no buckets are configured, if dict bucket weights do not
                sum to 1 (within floating-point tolerance), or if the buckets are
                neither numbers nor dicts.
        """
        import math

        buckets = self.crop_aspect_buckets
        if not buckets:
            raise ValueError("Aspect buckets are not defined in the data backend config.")

        if self.valid_metadata:
            self.aspect_ratio = self.image_metadata["aspect_ratio"]
            return self.aspect_ratio

        # Only the first element is inspected to decide which bucket format is in use.
        first = buckets[0]

        if isinstance(first, dict):
            aspects = [bucket["aspect"] for bucket in buckets]
            weights = [bucket["weight"] for bucket in buckets]
            # Diagnostic only; the draw below does not depend on orientation.
            logger.debug(
                "has_portrait_buckets: %s, has_landscape_buckets: %s",
                any(aspect < 1.0 for aspect in aspects),
                any(aspect > 1.0 for aspect in aspects),
            )
            # Compare with a tolerance: e.g. ten weights of 0.1 sum to
            # 0.9999999999999999, which an exact `!= 1.0` check wrongly rejects.
            if not math.isclose(sum(weights), 1.0, rel_tol=1e-9, abs_tol=1e-6):
                raise ValueError("The weights of aspect buckets must add up to 1.")
            return random.choices(aspects, weights)[0]

        # Accept ints too (e.g. `1` parsed from JSON); bool is an int subclass but
        # is not a meaningful aspect ratio.
        if isinstance(first, (int, float)) and not isinstance(first, bool):
            available_aspects = self._trim_aspect_bucket_list()
            # len() rather than truthiness: the helper's return type is not visible here.
            if len(available_aspects) == 0:
                if should_log():
                    tqdm.write("[WARNING] Image dimensions do not fit into the configured aspect buckets. Using square crop.")
                return 1.0
            return random.choice(available_aspects)

        raise ValueError(
            "Aspect buckets must be a list of floats or dictionaries."
            " If using a dictionary, it is expected to be in the format {'aspect': 1.0, 'weight': 0.5}."
            " To provide multiple aspect ratios, use a list of dictionaries: [{'aspect': 1.0, 'weight': 0.5}, {'aspect': 1.5, 'weight': 0.5}]."
        )

twri avatar Aug 11 '24 19:08 twri

One might also want to pick the closest bucket instead of a random one.

twri avatar Aug 11 '24 19:08 twri

Something like this. I had to update a few places to support "closest".

    def _select_random_aspect(self):
        """
        Return the aspect ratio to crop the current image to.

        If the image already has valid cached metadata, its stored aspect ratio is
        returned (and written to ``self.aspect_ratio``) regardless of mode.
        Otherwise behaviour depends on ``self.crop_aspect``:

        * ``"closest"``: return the configured bucket aspect nearest to
          ``self.aspect_ratio`` (bucket weights are ignored; ties go to the
          first bucket listed).
        * ``"random"``: for dict buckets (``{'aspect': ..., 'weight': ...}``) draw
          one with probability proportional to its weight; for numeric buckets
          pick uniformly from ``self._trim_aspect_bucket_list()``, falling back
          to a square (1.0) crop if that list is empty.
        * any other value: return 1.0 (square).

        Returns:
            float: The selected aspect ratio.

        Raises:
            ValueError: If no buckets are configured, if dict bucket weights do not
                sum to 1 (within floating-point tolerance), or if, in "random"
                mode, the buckets are neither numbers nor dicts.
        """
        import math

        buckets = self.crop_aspect_buckets
        if not buckets:
            raise ValueError("Aspect buckets are not defined in the data backend config.")

        if self.valid_metadata:
            self.aspect_ratio = self.image_metadata["aspect_ratio"]
            return self.aspect_ratio

        def bucket_aspect(bucket):
            # Buckets are either plain numbers or {'aspect': ..., 'weight': ...} dicts.
            return bucket["aspect"] if isinstance(bucket, dict) else bucket

        # 'closest' mode: pick the bucket aspect nearest the image's own aspect ratio.
        if self.crop_aspect == "closest":
            # NOTE(review): relies on self.aspect_ratio having been set for this image
            # before this call (not visible here) -- confirm with the caller.
            closest_aspect_value = min(
                (bucket_aspect(bucket) for bucket in buckets),
                key=lambda aspect: abs(aspect - self.aspect_ratio),
            )
            # debug, not info: this runs once per image.
            logger.debug(
                "Selected closest aspect: %s for aspect ratio: %s",
                closest_aspect_value,
                self.aspect_ratio,
            )
            return closest_aspect_value

        # 'random' mode: weighted draw for dict buckets, uniform draw for numeric ones.
        if self.crop_aspect == "random":
            first = buckets[0]

            if isinstance(first, dict):
                aspects = [bucket["aspect"] for bucket in buckets]
                weights = [bucket["weight"] for bucket in buckets]
                # Diagnostic only; every configured bucket is a candidate, so a
                # portrait-only or landscape-only list still yields non-square crops.
                logger.debug(
                    "has_portrait_buckets: %s, has_landscape_buckets: %s",
                    any(aspect < 1.0 for aspect in aspects),
                    any(aspect > 1.0 for aspect in aspects),
                )
                # Tolerant comparison: e.g. ten weights of 0.1 sum to
                # 0.9999999999999999, which an exact `!= 1.0` check wrongly rejects.
                if not math.isclose(sum(weights), 1.0, rel_tol=1e-9, abs_tol=1e-6):
                    raise ValueError("The weights of aspect buckets must add up to 1.")
                return random.choices(aspects, weights)[0]

            # Accept ints too (e.g. `1` parsed from JSON); bool is excluded.
            if isinstance(first, (int, float)) and not isinstance(first, bool):
                available_aspects = self._trim_aspect_bucket_list()
                # len() rather than truthiness: the helper's return type is not visible here.
                if len(available_aspects) == 0:
                    if should_log():
                        tqdm.write("[WARNING] Image dimensions do not fit into the configured aspect buckets. Using square crop.")
                    return 1.0
                return random.choice(available_aspects)

            raise ValueError(
                "Aspect buckets must be a list of floats or dictionaries."
                " If using a dictionary, it is expected to be in the format {'aspect': 1.0, 'weight': 0.5}."
                " To provide multiple aspect ratios, use a list of dictionaries: [{'aspect': 1.0, 'weight': 0.5}, {'aspect': 1.5, 'weight': 0.5}]."
            )

        # Any other crop_aspect value falls back to a square crop.
        return 1.0

twri avatar Aug 11 '24 20:08 twri