[Bug] Loss becomes NaN when trained WAN2.2 low-noise model weights are used for further low-noise VSA training
Describe the bug
When the low-noise model weights trained in WAN2.2 are used for further low-noise VSA training, the loss becomes NaN. Our investigation revealed that the NaN is generated during the VSA backward pass.
usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py:824: UserWarning: Error detected in GeneratedBackwardFor_vsa_block_sparse_attn_triton_defaultBackward. Traceback of forward call that caused the error: File "/data/juicefs_sharing_data/public_data/11140845/code/11.5/vivolm-flash/vivolm_flash/train.py", line 102, in <module> sys.exit(recipe_main()) File "/data/juicefs_sharing_data/public_data/11140845/code/11.5/vivolm-flash/torchtune/config/_parse.py", line 99, in wrapper sys.exit(recipe_main(conf)) File "/data/juicefs_sharing_data/public_data/11140845/code/11.5/vivolm-flash/vivolm_flash/train.py", line 98, in recipe_main recipe.train() File "/data/juicefs_sharing_data/public_data/11140845/code/11.5/vivolm-flash/vivolm_flash/v_recipes/flash_train.py", line 464, in train current_loss = self._model(batch) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward output = self._fsdp_wrapped_module(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/data/vjuicefs_ai_camera_pgroup_ql/public_data/VideoGeneration/CineArt/CineArt-0.4/code/video_generation_aio/examples/wanvideo/model_training/train_ming.py", line 128, in forward loss = self.pipe.training_loss(**models, **inputs) # 去噪和计算loss File "/data/juicefs_sharing_data/public_data/11140845/code/11.5/vivolm-flash/vivolm_flash/modules/attention/vsa.py", line 214, in new_vsa_module_training_loss noise_pred = 
self.model_fn(**inputs, timestep=timestep) File "/data/vjuicefs_ai_camera_pgroup_ql/public_data/VideoGeneration/CineArt/CineArt-0.4/code/video_generation_aio/diffsynth/pipelines/wan_video_new_ming.py", line 1196, in model_fn_wan_video x = torch.utils.checkpoint.checkpoint( File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 51, in inner return disable_fn(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 749, in _fn return fn(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/utils/checkpoint.py", line 495, in checkpoint ret = function(*args, **kwargs) File "/data/vjuicefs_ai_camera_pgroup_ql/public_data/VideoGeneration/CineArt/CineArt-0.4/code/video_generation_aio/diffsynth/pipelines/wan_video_new_ming.py", line 1184, in custom_forward return module(*inputs) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward output = self._fsdp_wrapped_module(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/data/vjuicefs_ai_camera_pgroup_ql/public_data/VideoGeneration/CineArt/CineArt-0.4/code/video_generation_aio/diffsynth/models/wan_video_dit.py", line 226, in forward x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File 
"/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/data/juicefs_sharing_data/public_data/11140845/code/11.5/vivolm-flash/vivolm_flash/parallel/sequence_parallel/sequence_parallel.py", line 156, in new_forward output = video_sparse_attn( File "/usr/local/lib/python3.12/dist-packages/vsa/__init__.py", line 73, in video_sparse_attn output_select, _ = block_sparse_attn(q, k, v, block_mask, variable_block_sizes) File "/usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py", line 671, in __call__ return self._opoverload(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 756, in __call__ return self._op(*args, **kwargs) File "/usr/local/lib/python3.12/dist-packages/torch/_library/autograd.py", line 111, in autograd_impl result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined] File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 578, in apply return super().apply(*args, **kwargs) # type: ignore[misc] (Triggered internally at /opt/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.) 
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass [rank7]: Traceback (most recent call last): [rank7]: File "/data/juicefs_sharing_data/public_data/11140845/code/11.5/vivolm-flash/vivolm_flash/train.py", line 102, in <module> [rank7]: sys.exit(recipe_main()) [rank7]: ^^^^^^^^^^^^^ [rank7]: File "/data/juicefs_sharing_data/public_data/11140845/code/11.5/vivolm-flash/torchtune/config/_parse.py", line 99, in wrapper [rank7]: sys.exit(recipe_main(conf)) [rank7]: ^^^^^^^^^^^^^^^^^ [rank7]: File "/data/juicefs_sharing_data/public_data/11140845/code/11.5/vivolm-flash/vivolm_flash/train.py", line 98, in recipe_main [rank7]: recipe.train() [rank7]: File "/data/juicefs_sharing_data/public_data/11140845/code/11.5/vivolm-flash/vivolm_flash/v_recipes/flash_train.py", line 468, in train [rank7]: current_loss.backward() [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 648, in backward [rank7]: torch.autograd.backward( [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 353, in backward [rank7]: _engine_run_backward( [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 824, in _engine_run_backward [rank7]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank7]: RuntimeError: Function 'GeneratedBackwardFor_vsa_block_sparse_attn_triton_defaultBackward' returned nan values in its 0th output.
However, training a high-noise VSA model using a pre-trained high-noise model (e.g., WAN2.2) works normally. The only differences are the initial model weights and the noise range. Have any of you encountered this problem?
Reproduction
- Low-noise model: Wan2.2-Fun-A14B-Control:high_noise_model/diffusion_pytorch_model.safetensors
- High-noise model: Wan2.2-Fun-A14B-Control:high_noise_model/diffusion_pytorch_model.safetensors
Environment
L40s
torch : 2.7.0
Could you share a bit more about what you are trying to train? Are you trying to train VSA with the Wan2.2 MoE model?
Could you share a bit more about what you are trying to train? Are you trying to train VSA + Wan2.2 moe model?
Yes, I used WAN2.2 + VSA for training. WAN2.2 training is divided into two phases: a high-noise model training phase and a low-noise model training phase. At inference time, the two models are then used together. I trained VSA on top of the already-trained WAN2.2 model (besides loading the weights, I also added a randomly initialized gate_compress layer). High-noise training went smoothly, but low-noise training produced NaN errors immediately.
Do you have a branch with your scripts? I think there may be a bug in the backward pass of our VSA Triton kernel, but I'm not sure whether it is the bug you are running into. It might be worth trying the ThunderKittens implementation, which only supports Hopper GPUs, to see if you experience the same issue. cc @jzhang38
OK, I'll try the ThunderKittens implementation on an H800 (Hopper GPU) first. However, there's a strange phenomenon here: the high-noise model trains very well. The high-noise and low-noise models use the same DiT architecture; the only differences in training are the initial weights and the noise range, yet the high-noise model trains normally. Additionally, I'm using the open-source WAN2.2 project and have incorporated FastVideo's VSA functionality. Without any training (i.e., loading the pre-trained high-noise and low-noise model weights from WAN2.2 and randomly initializing a gate_compress linear layer), the model can run inference normally, but the results are not very good. Therefore, we want to train it further.
do you have a branch with your scripts? I think there may be a bug in our VSA triton kernel bwd, but I'm not sure if this is the bug you are running into. It might be worth trying on the thunderkitten implementation, which only supports hopper GPUs to see if you experience the same issue. cc @jzhang38你有包含脚本的分支吗?我觉得我们的VSA triton内核反向传播可能有个bug,但不确定这是不是你遇到的那个问题。或许可以试试thunderkitten的实现,它只支持Hopper GPU,看看你是否会遇到同样的问题。抄送 @jzhang38
Does ThunderKittens only support the H100?
Just for reference, the bug mentioned by @SolitaryThinker was just fixed in the PR (https://github.com/hao-ai-lab/FastVideo/pull/879). The new version of VSA Triton kernel might also be worth trying.
> Just for reference, the bug mentioned by @SolitaryThinker was just fixed in the PR (#879). The new version of VSA Triton kernel might also be worth trying.

I tried the ThunderKittens implementation on an H800 and it works. However, the new version of VSA (0.0.4) still has the NaN issue. It seems the fix may not be included in the latest release, and when I tried to compile and install it from source, I encountered the following error:
FAILED: /data/juicefs_sharing_data/public_data/11140845/code/11.5/FastVideo-main/csrc/attn/video_sparse_attn/build/temp.linux-x86_64-cpython-312/vsa/block_sparse_h100.o /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /data/juicefs_sharing_data/public_data/11140845/code/11.5/FastVideo-main/csrc/attn/video_sparse_attn/build/temp.linux-x86_64-cpython-312/vsa/block_sparse_h100.o.d -I/usr/local/lib/python3.12/dist-packages/torch/include -I/usr/local/lib/python3.12/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/usr/include/python3.12 -c -c /data/juicefs_sharing_data/public_data/11140845/code/11.5/FastVideo-main/csrc/attn/video_sparse_attn/vsa/block_sparse_h100.cu -o /data/juicefs_sharing_data/public_data/11140845/code/11.5/FastVideo-main/csrc/attn/video_sparse_attn/build/temp.linux-x86_64-cpython-312/vsa/block_sparse_h100.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DNDEBUG -Xcompiler=-Wno-psabi -Xcompiler=-fno-strict-aliasing --expt-extended-lambda --expt-relaxed-constexpr -forward-unknown-to-host-compiler --use_fast_math -std=c++20 -O3 -Xnvlink=--verbose -Xptxas=--verbose -Xptxas=--warn-on-spills -I/data/juicefs_sharing_data/public_data/11140845/code/11.5/FastVideo-main/csrc/attn/video_sparse_attn/tk/include -I/data/juicefs_sharing_data/public_data/11140845/code/11.5/FastVideo-main/csrc/attn/video_sparse_attn/tk/prototype -I/usr/include/python3.12 -DTORCH_COMPILE -I/usr/local/lib/python3.12/dist-packages/torch/include -I/usr/local/lib/python3.12/dist-packages/torch/include/torch/csrc/api/include -DKITTENS_HOPPER -arch=sm_90a -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=vsa_cuda -D_GLIBCXX_USE_CXX11_ABI=1 
/data/juicefs_sharing_data/public_data/11140845/code/11.5/FastVideo-main/csrc/attn/video_sparse_attn/vsa/block_sparse_h100.cu:3:10: fatal error: kittens.cuh: No such file or directory 3 | #include "kittens.cuh" | ^~~~~~~~~~~~~ compilation terminated. ninja: build stopped: subcommand failed. Traceback (most recent call last): File "/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py", line 2240, in _run_ninja_build subprocess.run( File "/usr/lib/python3.12/subprocess.py", line 571, in run raise CalledProcessError(retcode, process.args, subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

Do you have any suggestions for fixing this? Thanks!