
SD3 ControlNet Script (and others?): dataset preprocessing cache depends on unrelated arguments

Open kentdan3msu opened this issue 9 months ago • 9 comments

Describe the bug

When using the SD3 ControlNet training script, the training dataset embeddings are precomputed and the result is cached under a fingerprint derived from the full set of input script arguments. Subsequent runs with an identical fingerprint reuse the cached preprocessed dataset instead of recomputing the embeddings, which in my experience takes a while. However, arguments completely unrelated to the dataset also affect this hash, so any minor change to runtime arguments triggers a full dataset remap.

The code in question: https://github.com/huggingface/diffusers/blob/ed4efbd63d0f6b271894bc404b12f512d6b764e5/examples/controlnet/train_controlnet_sd3.py#L1167-L1183

Ideally, the hash should only depend on arguments directly related to the dataset and the embeddings computation, namely --dataset_name, --dataset_config_name, --pretrained_model_name_or_path, --variant, and --revision (I could be wrong and other arguments may or may not affect the embeddings; someone please validate).

This issue might affect other example scripts. I haven't looked closely at the others, but if this caching technique is used elsewhere, it could be causing similar training startup delays, especially for people reusing the same dataset across multiple training runs.
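To make the failure mode concrete, here is a minimal stand-in for the fingerprinting (using a stdlib `hashlib` digest over the argument namespace in place of `datasets.fingerprint.Hasher`, which is what the script actually uses): two runs that differ only in a training-only argument produce different fingerprints, so the `map()` cache misses.

```python
import argparse
import hashlib
import json

def fingerprint(args: argparse.Namespace) -> str:
    # Stand-in for datasets.fingerprint.Hasher.hash: digest over ALL arguments.
    return hashlib.sha256(json.dumps(vars(args), sort_keys=True).encode()).hexdigest()

# Two runs on the same dataset, differing only in train_batch_size.
run1 = argparse.Namespace(dataset_name="fusing/fill50k", train_batch_size=2)
run2 = argparse.Namespace(dataset_name="fusing/fill50k", train_batch_size=1)

# The dataset is identical, but the fingerprint differs -> cache miss.
assert fingerprint(run1) != fingerprint(run2)
```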

Reproduction

Step 1: Start training an SD3 ControlNet model (specific dataset shouldn't matter, and you can exit the script after the map() is complete and training begins)

python3 examples/controlnet/train_controlnet_sd3.py --pretrained_model_name_or_path=/path/to/your/stable-diffusion-3-medium-diffusers --output_dir=/path/to/output --dataset_name=fusing/fill50k --resolution=1024 --learning_rate=1e-5 --train_batch_size=2 --dataset_preprocess_batch_size=500

Step 2: relaunch with the same arguments. Observe that the map() is skipped and training restarts fairly quickly.

Step 3: relaunch with --train_batch_size reduced to 1. Observe that the map() is restarted.

Proposed fix

For the SD3 ControlNet script: either deep-copy the arguments and remove those that do not affect the preprocessed dataset, or create a fresh argparse.Namespace() and copy over only the arguments that matter.

Option 1 would look something like:

import copy

from datasets.fingerprint import Hasher  # already imported in the training script

args_copy = copy.deepcopy(args)
for unwanted_arg in ['output_dir', 'train_batch_size', 'dataset_preprocess_batch_size',
                     'gradient_accumulation_steps', 'gradient_checkpointing', 'learning_rate',
                     'max_train_steps', 'checkpointing_steps', 'lr_num_cycles',
                     'validation_prompt', 'validation_image', 'validation_steps']:
    if hasattr(args_copy, unwanted_arg):
        delattr(args_copy, unwanted_arg)
new_fingerprint = Hasher.hash(args_copy)

Option 2 would look something like:

import argparse

from datasets.fingerprint import Hasher  # already imported in the training script

args_copy = argparse.Namespace()
for dataset_arg in ['dataset_name', 'dataset_config_name', 'pretrained_model_name_or_path', 'variant', 'revision']:
    setattr(args_copy, dataset_arg, getattr(args, dataset_arg))
new_fingerprint = Hasher.hash(args_copy)

Ideally, if the dataset is loaded from a local config, the contents of the config should be hashed instead of the filename itself, but fixing that would be a bit more complex (and might be beyond the scope of the example script).
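For reference, hashing a local config's contents rather than its filename could be as simple as the following sketch (`hash_local_config` is a hypothetical helper, not part of the training script):

```python
import hashlib
import os

def hash_local_config(path):
    # Hash the config file's contents rather than its filename, so edits to
    # the config invalidate the cache even when the path itself is unchanged.
    if path is None or not os.path.isfile(path):
        return None
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()
```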

System Info

  • 🤗 Diffusers version: 0.34.0.dev0
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.29.3
  • Transformers version: 4.50.0
  • Accelerate version: 1.5.2
  • PEFT version: 0.11.1
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA RTX A6000, 49140 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Tested with both single instance and accelerate with 2+ GPUs

Who can help?

@sayakpaul

kentdan3msu avatar May 05 '25 18:05 kentdan3msu

Good point, and thanks for the detailed issue. I would be in favor of adding a utility function to training_utils.py implementing option 1. Option 1 reads as more explicit to me, hence the preference.

Would you like to open a PR for this? That way your contribution stays in the repo and, IMO, is better recognized.

sayakpaul avatar May 06 '25 03:05 sayakpaul

Sure, I can open a PR. I'll try to have something ready in about a week.

Just so I'm on the right track, would something like this work for the interface at the training script level?

# Somewhere near the top of each training script
DATASET_AGNOSTIC_ARGS = ['output_dir', 'train_batch_size', 'dataset_preprocess_batch_size',
                         'gradient_accumulation_steps', 'gradient_checkpointing', 'learning_rate',
                         'max_train_steps', 'checkpointing_steps', 'add_lora', 'lr_num_cycles',
                         'validation_prompt', 'validation_image', 'validation_steps', 'dataset_test_config_name',
                         'resume_from_checkpoint', 'upcast_vae']
# Before the map()
dataset_hash = Hasher.hash(diffusers.training_utils.remove_arguments(args, DATASET_AGNOSTIC_ARGS))

I can write remove_arguments() so that it can accept either a list of strings, or just strings themselves, i.e.:

new_args = remove_arguments(args, 'argument_1', 'argument_2')
# or
new_args = remove_arguments(args, ['argument_1', 'argument_2'])

Depending on how complete this should be, a corresponding add_arguments() function could also be included.
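For what it's worth, a minimal sketch of such a `remove_arguments()` (hypothetical; the final version would live in `training_utils.py`) that accepts either individual strings or a single list:

```python
import argparse
import copy

def remove_arguments(args, *names):
    # Accept either remove_arguments(args, 'a', 'b')
    # or remove_arguments(args, ['a', 'b']).
    if len(names) == 1 and isinstance(names[0], (list, tuple)):
        names = names[0]
    # Deep-copy so the caller's namespace is left untouched.
    args_copy = copy.deepcopy(args)
    for name in names:
        if hasattr(args_copy, name):
            delattr(args_copy, name)
    return args_copy
```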

kentdan3msu avatar May 06 '25 19:05 kentdan3msu

I think the following should work:

# Somewhere near the top of each training script
DATASET_AGNOSTIC_ARGS = ['output_dir', 'train_batch_size', 'dataset_preprocess_batch_size',
                         'gradient_accumulation_steps', 'gradient_checkpointing', 'learning_rate',
                         'max_train_steps', 'checkpointing_steps', 'add_lora', 'lr_num_cycles',
                         'validation_prompt', 'validation_image', 'validation_steps', 'dataset_test_config_name',
                         'resume_from_checkpoint', 'upcast_vae']
# Before the map()
dataset_hash = Hasher.hash(diffusers.training_utils.remove_arguments(args, DATASET_AGNOSTIC_ARGS))

@lhoestq WDYT

sayakpaul avatar May 07 '25 06:05 sayakpaul

or simpler:

dataset_hash = Hasher.hash([args.dataset_name, args.dataset_config_name, args.train_data_dir, args.pretrained_model_name_or_path, args.revision, args.max_sequence_length])

edit: please double-check whether there are other parameters to take into account

lhoestq avatar May 07 '25 12:05 lhoestq

@lhoestq looks like you want something more akin to option 2, which I am perfectly fine with, and I agree that the code reads much more cleanly.

Should we include any local data file checks in case users load a local dataset? We can't cover every use case, but the local data dir should be supported. Checking the set of sample files and their modification dates should cover 99% of the changes people make to local datasets. Something like this would work:

data_hash = Hasher.hash(
    [(sample, os.path.getmtime(os.path.join(args.train_data_dir, sample)))
     for sample in os.listdir(args.train_data_dir)]
    if os.path.isdir(args.train_data_dir) else None
)
dataset_hash = Hasher.hash([args.dataset_name, args.dataset_config_name, args.train_data_dir, args.pretrained_model_name_or_path, args.revision, args.max_sequence_length, data_hash])

kentdan3msu avatar May 07 '25 17:05 kentdan3msu

The dataset origins (including file modification dates for local files) are already taken into account to compute the original dataset fingerprint, accessible at train_dataset._fingerprint.

So actually you can do

dataset_hash = Hasher.hash([train_dataset._fingerprint, args.pretrained_model_name_or_path, args.revision, args.max_sequence_length])
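To make the combination concrete, here is a self-contained stand-in (a stdlib digest in place of `Hasher.hash`, and a hard-coded string standing in for `train_dataset._fingerprint`): the combined hash changes if either the dataset fingerprint or any embedding-relevant argument changes, and nothing else.

```python
import hashlib
import json

def combined_hash(dataset_fingerprint, model_path, revision, max_sequence_length):
    # Stand-in for Hasher.hash([...]) over only the embedding-relevant inputs.
    payload = json.dumps([dataset_fingerprint, model_path, revision, max_sequence_length])
    return hashlib.sha256(payload.encode()).hexdigest()

h_same = combined_hash("abc123", "sd3-medium", None, 77)
h_repeat = combined_hash("abc123", "sd3-medium", None, 77)    # identical inputs
h_new_data = combined_hash("def456", "sd3-medium", None, 77)  # dataset changed

assert h_same == h_repeat and h_same != h_new_data
```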

lhoestq avatar May 12 '25 15:05 lhoestq

I wasn't aware that Dataset objects had their own fingerprint. That makes implementation much easier. I'll test it out on my end and see how well it works.

kentdan3msu avatar May 16 '25 15:05 kentdan3msu

Quick update: the dataset loader I wrote for my project does not produce a consistent fingerprint, so it recomputes the embeddings on every rerun even when neither the underlying data nor the base network has changed. I'm fairly certain it's due to the way I wrote the loader (I'm using a GeneratorBasedBuilder, and I suspect the generator object created at runtime affects the fingerprint).

I'm trying a rewrite of my dataset loader to see if it'll have a consistent fingerprint. If I can get my rewritten loader to give a consistent hash (and also output a different one if underlying files are modified), I'll push ahead on getting the PR for this submitted.

kentdan3msu avatar Jun 02 '25 19:06 kentdan3msu

Hi, GeneratorBasedBuilder is only for internal use (and is likely to get breaking changes in the future); you should use load_dataset() or Dataset.from_generator() instead.

In particular, Dataset.from_generator() returns a dataset whose fingerprint is a hash of the generator function passed as an argument. Let me know if that works!

lhoestq avatar Jun 05 '25 12:06 lhoestq