SD3 ControlNet Script (and others?): dataset preprocessing cache depends on unrelated arguments
Describe the bug
When using the SD3 ControlNet training script, the training dataset embeddings are precomputed, and the result is given a fingerprint based on the script arguments. Subsequent runs with the same fingerprint reuse the cached preprocessed dataset instead of recomputing the embeddings, which in my experience takes a while. However, arguments that are completely unrelated to the dataset also affect this hash, so even a minor change to a runtime argument triggers a full dataset remap.
The code in question: https://github.com/huggingface/diffusers/blob/ed4efbd63d0f6b271894bc404b12f512d6b764e5/examples/controlnet/train_controlnet_sd3.py#L1167-L1183
Ideally, the hash should only depend on arguments directly related to the dataset and the embeddings computation function, namely --dataset_name, --dataset_config_name, --pretrained_model_name_or_path, --variant, and --revision (I could be wrong, and other arguments may also affect the embeddings; someone please validate).
This issue might affect other example scripts. I haven't looked closely at them, but if the same technique is used elsewhere, it could cause similar training startup delays, especially for people who reuse the same dataset across multiple training runs.
Reproduction
Step 1: Start training an SD3 ControlNet model (the specific dataset shouldn't matter, and you can exit the script once the map() completes and training begins).
python3 examples/controlnet/train_controlnet_sd3.py --pretrained_model_name_or_path=/path/to/your/stable-diffusion-3-medium-diffusers --output_dir=/path/to/output --dataset_name=fusing/fill50k --resolution=1024 --learning_rate=1e-5 --train_batch_size=2 --dataset_preprocess_batch_size=500
Step 2: relaunch with the same arguments. Observe that the map() is skipped and training restarts fairly quickly.
Step 3: relaunch with --train_batch_size reduced to 1. Observe that the map() is restarted.
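The same sensitivity can be reproduced without launching training. This is a sketch under stated assumptions: fingerprint() is a hypothetical stand-in for Hasher.hash(args) built on hashlib and json, used only so the example runs with the standard library, and make_args() mirrors just a subset of the reproduction command. Any hash taken over the full set of arguments behaves the same way.

```python
import argparse
import hashlib
import json


def fingerprint(args: argparse.Namespace) -> str:
    """Hypothetical stand-in for Hasher.hash(args): SHA-256 over the sorted argument dict."""
    payload = json.dumps(vars(args), sort_keys=True, default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def make_args(train_batch_size: int) -> argparse.Namespace:
    # Subset of the reproduction command; only train_batch_size varies between runs.
    return argparse.Namespace(
        pretrained_model_name_or_path="/path/to/your/stable-diffusion-3-medium-diffusers",
        dataset_name="fusing/fill50k",
        resolution=1024,
        learning_rate=1e-5,
        train_batch_size=train_batch_size,
        dataset_preprocess_batch_size=500,
    )


run_1 = fingerprint(make_args(train_batch_size=2))  # Step 1
run_2 = fingerprint(make_args(train_batch_size=2))  # Step 2: identical arguments
run_3 = fingerprint(make_args(train_batch_size=1))  # Step 3: training-only change

print(run_1 == run_2)  # True: the cached map() result would be reused
print(run_1 == run_3)  # False: the map() would be recomputed
```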
Proposed fix
For the SD3 ControlNet script: either create a copy of the arguments and remove the ones that do not affect the embeddings, or create a new argparse.Namespace() and copy over only the arguments that matter.
Option 1 would look something like:
import copy
from datasets.fingerprint import Hasher
# Copy the arguments so the originals used for training are left untouched
args_copy = copy.deepcopy(args)
# Drop arguments that only affect training, not the precomputed embeddings
for unwanted_arg in ['output_dir', 'train_batch_size', 'dataset_preprocess_batch_size',
                     'gradient_accumulation_steps', 'gradient_checkpointing', 'learning_rate',
                     'max_train_steps', 'checkpointing_steps', 'lr_num_cycles',
                     'validation_prompt', 'validation_image', 'validation_steps']:
    if hasattr(args_copy, unwanted_arg):
        delattr(args_copy, unwanted_arg)
new_fingerprint = Hasher.hash(args_copy)
Option 2 would look something like:
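import argparse
from datasets.fingerprint import Hasher
# Start from an empty namespace and copy only the arguments that affect the embeddings
args_copy = argparse.Namespace()
for dataset_arg in ['dataset_name', 'dataset_config_name', 'pretrained_model_name_or_path', 'variant', 'revision']:
    # getattr with a default avoids an AttributeError if a script lacks one of these flags
    setattr(args_copy, dataset_arg, getattr(args, dataset_arg, None))
new_fingerprint = Hasher.hash(args_copy)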
Ideally, if the dataset is loaded from a local config, the contents of the config should be hashed instead of the filename itself, but fixing that would be a bit more complex (and might be beyond the scope of the example script).
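Hashing the config contents rather than its filename could look something like the sketch below. It assumes the local config is a single file; config_fingerprint() is a hypothetical helper, the file contents are placeholders, and the resulting digest would be passed to the hasher in place of the filename.

```python
import hashlib
import os
import tempfile


def config_fingerprint(path: str, chunk_size: int = 1 << 20) -> str:
    """Hypothetical helper: hash the bytes of a config file, not its name."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in chunks so large files don't have to fit in memory at once.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


with tempfile.TemporaryDirectory() as tmp:
    a = os.path.join(tmp, "config_a.yaml")
    b = os.path.join(tmp, "config_b.yaml")
    for p in (a, b):
        with open(p, "w") as f:
            f.write("image_column: image\ncaption_column: text\n")  # placeholder contents
    # Different filenames, identical contents: same fingerprint.
    same_contents = config_fingerprint(a) == config_fingerprint(b)

    with open(b, "a") as f:
        f.write("conditioning_image_column: conditioning_image\n")
    # Editing the contents changes the fingerprint.
    edited = config_fingerprint(a) != config_fingerprint(b)

print(same_contents, edited)  # True True
```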
System Info
- 🤗 Diffusers version: 0.34.0.dev0
- Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.11.11
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.29.3
- Transformers version: 4.50.0
- Accelerate version: 1.5.2
- PEFT version: 0.11.1
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA RTX A6000, 49140 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Tested with both single instance and accelerate with 2+ GPUs
Who can help?
@sayakpaul
Good point, and thanks for the detailed issue. I would be in favor of adding a utility function to training_utils.py with option 1. Option 1 reads as more explicit to me, hence the preference.
Would you like to open a PR for this? That way, your contribution stays in the repo and is, IMO, better respected.
Sure, I can open a PR. I'll try to have something ready in about a week.
Just so I'm on the right track, would something like this work for the interface at the training script level?
# Somewhere near the top of each training script
DATASET_AGNOSTIC_ARGS = ['output_dir', 'train_batch_size', 'dataset_preprocess_batch_size',
                         'gradient_accumulation_steps', 'gradient_checkpointing', 'learning_rate',
                         'max_train_steps', 'checkpointing_steps', 'add_lora', 'lr_num_cycles',
                         'validation_prompt', 'validation_image', 'validation_steps', 'dataset_test_config_name',
                         'resume_from_checkpoint', 'upcast_vae']
# Before the map()
dataset_hash = Hasher.hash(diffusers.training_utils.remove_arguments(args, DATASET_AGNOSTIC_ARGS))
I can write remove_arguments() so that it accepts either a list of strings or the strings themselves as separate arguments, e.g.:
new_args = remove_arguments(args, 'argument_1', 'argument_2')
# or
new_args = remove_arguments(args, ['argument_1', 'argument_2'])
Depending on how complete this should be, a corresponding add_arguments() function could also be included.
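A sketch of how the proposed helpers might be written, assuming they live in training_utils.py. Both remove_arguments() and add_arguments() are hypothetical here (they are proposals in this thread, not existing diffusers APIs), and the keyword-based signature for add_arguments() is an assumption. Both return copies so the original args used for training are never mutated.

```python
import argparse
import copy
from typing import Iterable, List, Union


def _flatten(names) -> List[str]:
    # Accept remove_arguments(args, "a", "b") as well as remove_arguments(args, ["a", "b"]).
    flat = []
    for name in names:
        if isinstance(name, str):
            flat.append(name)
        else:
            flat.extend(name)
    return flat


def remove_arguments(args: argparse.Namespace, *names: Union[str, Iterable[str]]) -> argparse.Namespace:
    """Hypothetical helper: return a copy of args without the named attributes."""
    new_args = copy.deepcopy(args)
    for name in _flatten(names):
        if hasattr(new_args, name):
            delattr(new_args, name)
    return new_args


def add_arguments(args: argparse.Namespace, **values) -> argparse.Namespace:
    """Hypothetical counterpart: return a copy of args with extra attributes set."""
    new_args = copy.deepcopy(args)
    for name, value in values.items():
        setattr(new_args, name, value)
    return new_args


args = argparse.Namespace(pretrained_model_name_or_path="model", train_batch_size=2, learning_rate=1e-5)
print(vars(remove_arguments(args, "train_batch_size", "learning_rate")))
print(vars(remove_arguments(args, ["train_batch_size", "learning_rate"])))
# both print {'pretrained_model_name_or_path': 'model'}
print(add_arguments(args, max_sequence_length=77).max_sequence_length)  # 77
```

Returning a copy keeps the helpers safe to call right before the map(), since the training loop still needs train_batch_size and the other removed values.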
I think the proposed interface should work.
@lhoestq WDYT?
or simpler:
dataset_hash = Hasher.hash([args.dataset_name, args.dataset_config_name, args.train_data_dir, args.pretrained_model_name_or_path, args.revision, args.max_sequence_length])
edit: pls double check if there are other parameters to take into account
@lhoestq looks like you want something more akin to option 2, which I am perfectly fine with, and I agree that the code reads much more cleanly.
Should we include any local data file checks in case users are loading a local dataset? I know we can't cover every use case, but a local data directory should be supported. If we just check the list of sample files and their modification dates, that should cover 99% of the changes people make to their local datasets. I think something like this would work:
import os
from datasets.fingerprint import Hasher
data_hash = Hasher.hash(sorted((sample, os.path.getmtime(os.path.join(args.train_data_dir, sample))) for sample in os.listdir(args.train_data_dir))) if args.train_data_dir and os.path.isdir(args.train_data_dir) else None
dataset_hash = Hasher.hash([args.dataset_name, args.dataset_config_name, args.train_data_dir, args.pretrained_model_name_or_path, args.revision, args.max_sequence_length, data_hash])
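The listing above only looks at the top level of --train_data_dir. If a local dataset is split into subfolders (for example separate image and conditioning-image folders, which is an assumption about the layout), a recursive variant might look like this. local_data_signature() is a hypothetical helper; its return value would stand in for the list that gets hashed into data_hash.

```python
import os
import tempfile
import time


def local_data_signature(root):
    """Hypothetical helper: sorted (relative path, mtime) pairs for every file under root."""
    if not root or not os.path.isdir(root):
        return None
    entries = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            full_path = os.path.join(dirpath, name)
            # Relative paths keep the signature independent of where the folder lives.
            entries.append((os.path.relpath(full_path, root), os.path.getmtime(full_path)))
    return sorted(entries)  # os.walk order is not guaranteed, so sort for determinism


with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "images"))
    sample = os.path.join(root, "images", "0001.png")
    with open(sample, "wb") as f:
        f.write(b"placeholder")
    before = local_data_signature(root)
    unchanged = local_data_signature(root) == before
    # Bump the modification time explicitly instead of relying on timing.
    os.utime(sample, (time.time(), os.path.getmtime(sample) + 10))
    changed = local_data_signature(root) != before

print(unchanged, changed, local_data_signature(None))  # True True None
```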
The dataset origins (including file modification dates for local files) are already taken into account to compute the original dataset fingerprint, accessible at train_dataset._fingerprint.
So actually you can do
dataset_hash = Hasher.hash([train_dataset._fingerprint, args.pretrained_model_name_or_path, args.revision, args.max_sequence_length])
I wasn't aware that Dataset objects had their own fingerprint. That makes implementation much easier. I'll test it out on my end and see how well it works.
Quick update: the dataset loader I wrote for the project I'm working on does not have a consistent fingerprint, so the embeddings are recomputed every time I rerun the code, even without modifying the underlying data or the baseline network. However, I'm fairly certain this is due to the way I wrote the loader (I'm using a GeneratorBasedBuilder, and I think the generator object created at runtime affects the fingerprint).
I'm trying a rewrite of my dataset loader to see if it'll have a consistent fingerprint. If I can get my rewritten loader to give a consistent hash (and also output a different one if underlying files are modified), I'll push ahead on getting the PR for this submitted.
Hi, GeneratorBasedBuilder is only for internal use (and is likely to get breaking changes in the future). You should use load_dataset() or Dataset.from_generator().
In particular, Dataset.from_generator() returns a dataset whose fingerprint is a hash of the generator function passed as an argument. Let me know if that works!