NaN Loss after Forward Pass on Step 4 when Full Fine-Tuning Wan 2.2 5B

Open wyyfffff opened this issue 3 months ago • 9 comments

Hi, I am currently full fine-tuning the Wan 2.2 5B model, but the loss becomes NaN right after the forward pass at the 4th training step, every time.

I have verified that the issue is not related to specific data samples — each training run uses a different random seed, so the order of loaded input videos changes every time, yet the NaN loss still consistently occurs at the same step.

What could be the possible reasons for this NaN loss?

Thanks!

wyyfffff avatar Oct 09 '25 17:10 wyyfffff
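One way to narrow down where the NaN first appears in the forward pass is to attach forward hooks that record every module whose output is non-finite. This is a minimal sketch in plain PyTorch, not part of DiffSynth-Studio; the helper name `attach_nan_hooks` is hypothetical.

```python
import torch
from torch import nn


def attach_nan_hooks(model: nn.Module):
    """Register forward hooks that record modules with non-finite outputs.

    Returns (bad, handles): `bad` is filled in forward-completion order, so
    bad[0] is the first module whose output contained NaN or Inf.
    """
    bad = []

    def make_hook(name):
        def hook(module, inputs, output):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for o in outs:
                if torch.is_tensor(o) and not torch.isfinite(o).all():
                    bad.append(name)
                    break
        return hook

    handles = [
        m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()
    ]
    return bad, handles
```

After a single forward pass on the failing step, `bad[0]` names the earliest offending submodule. Call `h.remove()` on each handle afterwards to restore normal speed.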

@wyyfffff Please reduce the learning rate. 1e-6 is safe.

Artiprocher avatar Oct 10 '25 06:10 Artiprocher

@Artiprocher Thank you for your reply!

I have already reduced the learning rate to 1e-8, with learning rate warmup and accelerator.clip_grad_norm_(model.trainable_modules(), 1.0) enabled.

However, the strange thing is that when I trained the model on an A800 server, everything worked fine, but when I ran the same code on an H20, I encountered the NaN loss issue.

I don’t think it’s related to the dataset, because I trained another model on the same dataset without any problems.

I suspect the issue might be related to model initialization, as I added a new patchify layer to the Wan model, which I initialized using kaiming uniform.

Have you tried training on H20 before? Or are there any specific considerations I should be aware of when initializing new layers in the Wan model?

Thanks!

wyyfffff avatar Oct 10 '25 09:10 wyyfffff
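On the question of initializing new layers: one common practice when adding a branch to a pretrained model is to zero-initialize it, so that the model's output at step 0 matches the pretrained model exactly. The sketch below assumes the new patchify layer is an additional branch whose output is added to the original patch embedding; the class name, channel counts, and the (1, 2, 2) patch size are illustrative assumptions, not taken from the thread.

```python
import torch
from torch import nn


class PatchEmbedWithExtraBranch(nn.Module):
    """Hypothetical patch embedding with an extra, zero-initialized branch."""

    def __init__(self, in_ch: int, extra_ch: int, dim: int, patch=(1, 2, 2)):
        super().__init__()
        # Stands in for the pretrained patch embedding.
        self.base = nn.Conv3d(in_ch, dim, kernel_size=patch, stride=patch)
        # Newly added branch: zero weights and bias, so it contributes nothing at init.
        self.extra = nn.Conv3d(extra_ch, dim, kernel_size=patch, stride=patch)
        nn.init.zeros_(self.extra.weight)
        nn.init.zeros_(self.extra.bias)

    def forward(self, x: torch.Tensor, extra: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.extra(extra)
```

If the new layer replaces the original patch embedding instead of adding to it, zero initialization would zero out the input, so a different scheme would be needed.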

@wyyfffff I encountered a similar issue. I reduced the learning rate to 1e-8 and applied gradient clipping with accelerator.clip_grad_norm_(model.trainable_modules(), 1.0). My training uses DeepSpeed ZeRO-2 on four GPUs, and I’ve tested it on two types of devices (RTX 5880 Ada and H800). Interestingly, the training sometimes runs normally on certain GPUs but produces NaNs on others. Have you solved this problem?

zzhang2816 avatar Oct 24 '25 11:10 zzhang2816
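Gradient clipping cannot repair gradients that are already NaN or Inf, because the computed total norm is then non-finite as well. The sketch below skips the optimizer step in that case. It assumes a plain PyTorch loop (not the Accelerate or DeepSpeed wrappers used in the thread), and the function name `clipped_step` is hypothetical.

```python
import torch
from torch import nn


def clipped_step(model: nn.Module, optimizer, loss: torch.Tensor, max_norm: float = 1.0):
    """Backward, clip, and step; skip the update if anything is non-finite.

    Returns (stepped, total_norm).
    """
    optimizer.zero_grad(set_to_none=True)
    if not torch.isfinite(loss):
        # The forward pass already produced NaN/Inf; do not touch the weights.
        return False, float("nan")
    loss.backward()
    # clip_grad_norm_ returns the total gradient norm measured before clipping.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if not torch.isfinite(total_norm):
        optimizer.zero_grad(set_to_none=True)
        return False, float(total_norm)
    optimizer.step()
    return True, float(total_norm)
```

Logging the returned `total_norm` at each step also shows whether the gradients blow up gradually or the NaN appears abruptly.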

Hello, I solved the problem by replacing fetch_model with the traditional initialization method:

dit = WanModel(**config)
dit.load_state_dict(state_dict)  # state_dict: the DiT weights loaded from the checkpoint files

You might try this

wyyfffff avatar Oct 24 '25 11:10 wyyfffff

For vae and text_encoder, you can still use fetch_model. But if you have modified the layers in dit, I think it’s better to try my method.

wyyfffff avatar Oct 24 '25 12:10 wyyfffff
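When loading weights into a model with modified layers, it can help to check which keys were actually loaded and whether every parameter is finite afterwards. This is a minimal sketch using only the standard `load_state_dict` return value; the helper name `load_and_report` is hypothetical.

```python
import torch
from torch import nn


def load_and_report(model: nn.Module, state_dict: dict):
    """Load with strict=False and report what did not line up.

    Returns (missing_keys, unexpected_keys, non_finite_params).
    """
    result = model.load_state_dict(state_dict, strict=False)
    # Parameters containing NaN/Inf after loading.
    non_finite = [
        name for name, p in model.named_parameters()
        if not torch.isfinite(p).all()
    ]
    return list(result.missing_keys), list(result.unexpected_keys), non_finite
```

Any name in `missing_keys` was not provided by the checkpoint and keeps whatever values it had before the call, so those are the parameters to inspect first.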

I will try it soon, thanks a lot! @wyyfffff

zzhang2816 avatar Oct 25 '25 12:10 zzhang2816

I have rewritten the model loading logic using the traditional initialization method, but the problem persists. I'm not sure what is wrong.

def load_dit(file_path, model_class, torch_dtype, device):
    state_dict = {}
    for path in file_path:
        state_dict.update(load_state_dict(path))

    state_dict_converter = model_class.state_dict_converter()
    model_state_dict, extra_kwargs = state_dict_converter.from_civitai(state_dict)
    model = model_class(**extra_kwargs)
    model = model.to_empty(device=device)
    model.load_state_dict(model_state_dict, strict=False)
    model = model.to(dtype=torch_dtype, device=device)
    return model


dit = load_dit(
    dit_config.path,
    model_class=WanModel,
    torch_dtype=model_config.offload_dtype or torch_dtype,
    device=model_config.offload_device or device,
)

zzhang2816 avatar Oct 27 '25 13:10 zzhang2816
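One thing that may be worth checking in a loader like this: `to_empty` allocates uninitialized storage, so after `load_state_dict(..., strict=False)` any parameter missing from the checkpoint (for example a newly added layer) keeps whatever happened to be in memory, which can differ from machine to machine. The sketch below initializes the missing tensors explicitly. The helper name `materialize_and_load` and the zero initialization are assumptions for illustration.

```python
import torch
from torch import nn


def materialize_and_load(model: nn.Module, state_dict: dict,
                         device="cpu", init_fn=nn.init.zeros_):
    """Allocate storage with to_empty, load weights, then init what is missing."""
    model = model.to_empty(device=device)  # contents are uninitialized here
    result = model.load_state_dict(state_dict, strict=False)
    missing = set(result.missing_keys)
    with torch.no_grad():
        tensors = list(model.named_parameters()) + list(model.named_buffers())
        for name, tensor in tensors:
            if name in missing:
                init_fn(tensor)  # overwrite the uninitialized memory
    return model, sorted(missing)
```

Passing a different `init_fn` allows a non-zero scheme for the new layers, as long as every missing tensor is written at least once.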

@zzhang2816 Below is my modified code; maybe you can try it.

In my case, I removed the Wan safetensors paths from model_configs and pass them separately as a new argument, wan_paths: list[str].

    @staticmethod
    def from_pretrained(
        wan_paths,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        # tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
        # audio_processor_config: ModelConfig = None,
        redirect_common_files: bool = True,
        use_usp=False,
    ):
        # Redirect model path
        if redirect_common_files:
            redirect_dict = {
                "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
            }
            for model_config in model_configs:
                if model_config.origin_file_pattern is None or model_config.model_id is None:
                    continue
                if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
                    print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
                    model_config.model_id = redirect_dict[model_config.origin_file_pattern]
        
        # Initialize pipeline
        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        if use_usp: pipe.initialize_usp()
        
        # Download and load models
        model_manager = ModelManager()
        for model_config in model_configs:
            model_config.download_if_necessary(use_usp=use_usp)
            model_manager.load_model(
                model_config.path,
                # device=model_config.offload_device or device,
                device='cpu',
                torch_dtype=model_config.offload_dtype or torch_dtype
            )
        
        # Load models
        # pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
        ## my load wanmodel

        print(f"====== go load wan config ======")
        wan22_config_path = 'z_my_wan22_5B_config.json'
        with open(wan22_config_path, "r") as f:
            config = json.load(f)
        dit = WanModel(**config)

        print(f"====== go load wan weight ======")
        dit_state_dict = {}
        for each in wan_paths:
            dit_state_dict.update(load_file(each))
        # strict=False returns the keys the checkpoint did not provide (e.g. newly added layers)
        missing, unexpected = dit.load_state_dict(dit_state_dict, strict=False)
        with torch.no_grad():
            miss = set(missing)
            # Zero-initialize every parameter and buffer that was not loaded from the checkpoint
            for name, p in dit.named_parameters():
                if name in miss:
                    p.zero_()
            for name, b in dit.named_buffers():
                if name in miss:
                    b.zero_()
        print(f"====== load wan weight ok!!! ======")
        pipe.dit = dit
        # print(dit)
        # dit = model_manager.fetch_model("wan_video_dit", index=2)
        # if isinstance(dit, list):
        #     pipe.dit, pipe.dit2 = dit
        # else:
        #     pipe.dit = dit
        pipe.vae = model_manager.fetch_model("wan_video_vae")
        # pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
        # pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
        # pipe.vace = model_manager.fetch_model("wan_video_vace")
        # pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")

        # Size division factor
        if pipe.vae is not None:
            pipe.height_division_factor = pipe.vae.upsampling_factor * 2
            pipe.width_division_factor = pipe.vae.upsampling_factor * 2

        # Initialize tokenizer
        # tokenizer_config.download_if_necessary(use_usp=use_usp)
        # pipe.prompter.fetch_models(pipe.text_encoder)
        # pipe.prompter.fetch_tokenizer(tokenizer_config.path)

        # if audio_processor_config is not None:
        #     audio_processor_config.download_if_necessary(use_usp=use_usp)
        #     from transformers import Wav2Vec2Processor
        #     pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)
        # Unified Sequence Parallel
        if use_usp: pipe.enable_usp()
        return pipe

wyyfffff avatar Oct 29 '25 06:10 wyyfffff

Thanks for sharing! I’ll give it a try. @wyyfffff

zzhang2816 avatar Oct 31 '25 09:10 zzhang2816