FastVideo
[Bug] num_latent_t causes shape mismatch
Describe the bug
All videos in my training data have the same number of frames, 81, which is the recommended frame count for Wan2.1 in the official repo. Because the Wan VAE has a temporal compression ratio of 4 (so 81 frames become (81 - 1) / 4 + 1 = 21 latent frames), I set num_latent_t to 21 for both preprocessing and training.
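For reference, this is the arithmetic I used (a minimal sketch, assuming the Wan VAE keeps the first frame and compresses the remaining frames 4x temporally):

def latent_frames(num_frames: int, temporal_compression: int = 4) -> int:
    # 81 pixel frames -> (81 - 1) // 4 + 1 = 21 latent frames
    return (num_frames - 1) // temporal_compression + 1

print(latent_frames(81))  # 21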
Specifically, for preprocessing I did this:
#!/bin/bash
GPU_NUM=1
MODEL_PATH="Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
MODEL_TYPE="wan"
DATA_MERGE_PATH="data/agibot_sub/merge.txt"
OUTPUT_DIR="data/agibot_sub_processed_i2v/"
VALIDATION_PATH="examples/training/finetune/wan_agibot/validation.json"
torchrun --nproc_per_node=$GPU_NUM \
    fastvideo/v1/pipelines/preprocess/v1_preprocess.py \
    --model_path $MODEL_PATH \
    --data_merge_path $DATA_MERGE_PATH \
    --preprocess_video_batch_size 16 \
    --max_height 480 \
    --max_width 640 \
    --num_frames 81 \
    --dataloader_num_workers 0 \
    --output_dir=$OUTPUT_DIR \
    --model_type $MODEL_TYPE \
    --train_fps 16 \
    --samples_per_file 8 \
    --flush_frequency 8 \
    --num_latent_t 21 \
    --video_length_tolerance_range 1 \
    --preprocess_task "i2v" \
    --validation_dataset_file $VALIDATION_PATH
However, training then fails with the following error:
Traceback (most recent call last):
File "/<PROJECT_ROOT>/fastvideo/v1/training/wan_i2v_training_pipeline.py", line 232, in <module>
main(args)
File "/<PROJECT_ROOT>/fastvideo/v1/training/wan_i2v_training_pipeline.py", line 219, in main
pipeline.train()
File "/<ENV_ROOT>/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/<PROJECT_ROOT>/fastvideo/v1/training/training_pipeline.py", line 443, in train
self._log_validation(self.transformer, self.training_args, 1)
File "/<ENV_ROOT>/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/<PROJECT_ROOT>/fastvideo/v1/pipelines/composed_pipeline_base.py", line 294, in forward
batch = stage(batch, fastvideo_args)
File "/<PROJECT_ROOT>/fastvideo/v1/pipelines/stages/base.py", line 168, in __call__
result = self.forward(batch, fastvideo_args)
File "/<PROJECT_ROOT>/fastvideo/v1/pipelines/stages/denoising.py", line 107, in forward
latents = rearrange(batch.latents,
File "/<ENV_ROOT>/lib/python3.12/site-packages/einops/einops.py", line 600, in rearrange
return reduce(tensor, pattern, reduction="rearrange", **axes_lengths)
File "/<ENV_ROOT>/lib/python3.12/site-packages/einops/einops.py", line 542, in reduce
raise EinopsError(message + "\n {}".format(e))
einops.EinopsError: Error while processing rearrange-reduction pattern "b t (n s) h w -> b t n s h w".
Input tensor shape: torch.Size([1, 16, 21, 60, 80]). Additional info: {'n': 8}.
Shape mismatch, can't divide axis of length 21 in chunks of 8
I checked the code and saw that n here is the sp_world_size (the sequence-parallel world size).
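The failure can be reproduced in isolation with just torch and einops (the shapes below are copied from the traceback; this only illustrates the divisibility check, it is not FastVideo code):

import torch
from einops import rearrange

latents = torch.zeros(1, 16, 21, 60, 80)  # third axis is num_latent_t = 21

# sp_world_size = 8: 21 is not divisible by 8, so the split fails
try:
    rearrange(latents, "b t (n s) h w -> b t n s h w", n=8)
except Exception as e:
    print(type(e).__name__)  # EinopsError

# With num_latent_t = 16 the same pattern works:
ok = rearrange(torch.zeros(1, 16, 16, 60, 80),
               "b t (n s) h w -> b t n s h w", n=8)
print(ok.shape)  # torch.Size([1, 16, 8, 2, 60, 80])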
How can I resolve this issue?
Reproduction
Following the Wan I2V finetuning example, but using my own dataset in which every video has 81 frames.
Environment
fastvideo
For now, num_latent_t must be divisible by sp_size. If you want to use sp_size = 8, you may need to use 61 frames, which gives num_latent_t = (61 - 1) / 4 + 1 = 16.
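In case it helps others hitting this, here is a hypothetical helper (not part of FastVideo) that walks down from a target frame count to the largest one whose latent length is divisible by sp_size, assuming Wan's 4x temporal compression with the first frame kept:

def largest_valid_num_frames(max_frames: int, sp_size: int,
                             temporal_compression: int = 4) -> int:
    # Valid Wan frame counts have the form temporal_compression * k + 1,
    # which give num_latent_t = k + 1 latent frames.
    for frames in range(max_frames, 0, -1):
        if (frames - 1) % temporal_compression:
            continue
        latent_t = (frames - 1) // temporal_compression + 1
        if latent_t % sp_size == 0:
            return frames
    raise ValueError("no valid frame count found")

print(largest_valid_num_frames(81, 8))  # 61 -> num_latent_t = 16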