NaN Loss after Forward Pass on Step 4 when Full Fine-Tuning Wan 2.2 5B
Hi, I am currently full fine-tuning the Wan 2.2 5B model, but the loss becomes NaN right after the forward pass at the 4th training step, every time.
I have verified that the issue is not related to specific data samples — each training run uses a different random seed, so the order of loaded input videos changes every time, yet the NaN loss still consistently occurs at the same step.
What could be the possible reasons for this NaN loss?
Thanks!
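Not part of the original post, but one way to narrow this down is to find which layer first produces a non-finite value during that 4th forward pass. A minimal sketch, assuming a standard PyTorch module (here called model); names are placeholders:

    import torch

    def install_nan_hooks(model: torch.nn.Module):
        # Register a forward hook on every submodule so the first module whose
        # output contains NaN/Inf is reported by name.
        def make_hook(name):
            def hook(module, inputs, output):
                outs = output if isinstance(output, (tuple, list)) else (output,)
                for t in outs:
                    if torch.is_tensor(t) and not torch.isfinite(t).all():
                        raise RuntimeError(f"Non-finite output detected in module: {name}")
            return hook
        for name, module in model.named_modules():
            module.register_forward_hook(make_hook(name))

torch.autograd.set_detect_anomaly(True) can similarly pinpoint the failing op in the backward pass, at a significant speed cost.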
@wyyfffff Please reduce the learning rate. 1e-6 is safe.
@Artiprocher Thank you for your reply!
I have already reduced the learning rate to 1e-8, with learning rate warmup and accelerator.clip_grad_norm_(model.trainable_modules(), 1.0) enabled.
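For reference, that setup corresponds roughly to the following sketch (the AdamW choice, the warmup length, and the loop structure are assumptions, not my exact training code; model and dataloader are placeholders defined elsewhere):

    import torch
    from accelerate import Accelerator
    from torch.optim.lr_scheduler import LambdaLR

    accelerator = Accelerator()
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-8)  # reduced from the suggested 1e-6

    warmup_steps = 1000  # warmup length is a placeholder
    scheduler = LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))

    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

    for step, batch in enumerate(dataloader):
        loss = model(**batch)  # assumes the model returns the training loss
        accelerator.backward(loss)
        accelerator.clip_grad_norm_(model.trainable_modules(), 1.0)  # gradient clipping as above
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()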
However, the strange thing is that when I trained the model on an A800 server everything worked fine, but when I ran the same code on an H20, I encountered the NaN loss issue.
I don't think it's related to the dataset, because I trained another model on the same dataset without any problems.
I suspect the issue might be related to model initialization, as I added a new patchify layer to the Wan model, which I initialized with Kaiming uniform.
Have you tried training on H20 before? Are there any specific considerations I should be aware of when initializing new layers in the Wan model?
Thanks!
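Regarding the new patchify layer, the initialization described above would look roughly like this (the layer type, channel sizes, and dtype are hypothetical placeholders, not the actual Wan patchify definition):

    import math
    import torch
    import torch.nn as nn

    # Hypothetical extra patch-embedding projection; dimensions are placeholders.
    new_patchify = nn.Conv3d(in_channels=48, out_channels=3072,
                             kernel_size=(1, 2, 2), stride=(1, 2, 2))

    # Kaiming-uniform initialization (a=sqrt(5) matches PyTorch's default for conv layers).
    nn.init.kaiming_uniform_(new_patchify.weight, a=math.sqrt(5))
    if new_patchify.bias is not None:
        nn.init.zeros_(new_patchify.bias)

    # Cast to the same dtype/device as the rest of the DiT so the new layer does not
    # introduce an unintended precision mismatch.
    new_patchify = new_patchify.to(dtype=torch.bfloat16, device="cuda")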
@wyyfffff I encountered a similar issue. I reduced the learning rate to 1e-8 and applied gradient clipping with accelerator.clip_grad_norm_(model.trainable_modules(), 1.0). My training uses DeepSpeed ZeRO-2 on four GPUs, and I've tested it on two types of devices (RTX 5880 Ada and H800). Interestingly, the training sometimes runs normally on certain GPUs but produces NaNs on others. Have you managed to solve this problem?
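Not from the thread itself, but since the NaNs appear on some GPUs and not others, a per-rank check right after the forward pass can at least show which device produces the first non-finite loss (sketch; accelerator, loss, and step come from the surrounding training loop):

    import torch

    # Inside the training loop, right after the loss is computed:
    if not torch.isfinite(loss.detach()).all():
        print(f"[rank {accelerator.process_index}] non-finite loss at step {step} "
              f"on {torch.cuda.get_device_name()}")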
Hello, I solved the problem by replacing fetch_model with the traditional initialization method:
dit = WanModel(**config)
dit.load_state_dict(state_dict)
You might try this
For vae and text_encoder, you can still use fetch_model. But if you have modified the layers in dit, I think it’s better to try my method.
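Expanded slightly, the idea above looks roughly like this (a sketch; the config path and the weight path list are placeholders, and load_file comes from safetensors.torch):

    import json
    import torch
    from safetensors.torch import load_file

    with open("wan2.2_5B_config.json") as f:      # placeholder config path
        config = json.load(f)
    dit = WanModel(**config)                       # build the (modified) DiT directly

    state_dict = {}
    for path in wan_safetensor_paths:              # placeholder list of .safetensors shards
        state_dict.update(load_file(path))

    # strict=False tolerates newly added layers; print what was skipped so nothing
    # important is silently left at its random initialization.
    missing, unexpected = dit.load_state_dict(state_dict, strict=False)
    print("missing keys:", missing)
    print("unexpected keys:", unexpected)

    dit = dit.to(dtype=torch.bfloat16, device="cuda")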
I will try it soon, thanks a lot! @wyyfffff
I have rewritten the model loading logic using the traditional initialization method, but the problem persists. I'm not sure what is wrong.
def load_dit(file_path, model_class, torch_dtype, device):
    state_dict = {}
    for path in file_path:
        state_dict.update(load_state_dict(path))
    state_dict_converter = model_class.state_dict_converter()
    model_state_dict, extra_kwargs = state_dict_converter.from_civitai(state_dict)
    model = model_class(**extra_kwargs)
    model = model.to_empty(device=device)
    model.load_state_dict(model_state_dict, strict=False)
    model = model.to(dtype=torch_dtype, device=device)
    return model

dit = load_dit(
    dit_config.path,
    model_class=WanModel,
    torch_dtype=model_config.offload_dtype or torch_dtype,
    device=model_config.offload_device or device,
)
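One thing worth checking with this version (a sanity-check sketch, not from the thread): to_empty() allocates uninitialized memory, so any key that strict=False leaves unmatched keeps whatever values happened to be in that memory. A quick scan after loading makes that visible:

    import torch

    def report_suspect_params(model: torch.nn.Module):
        # Flag parameters that are non-finite or suspiciously large after loading;
        # tensors left uninitialized by to_empty() typically show up here.
        for name, p in model.named_parameters():
            if not torch.isfinite(p).all():
                print(f"non-finite values in {name}")
            elif p.abs().max().item() > 1e4:
                print(f"unusually large values in {name} (max={p.abs().max().item():.3e})")

    report_suspect_params(dit)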
@zzhang2816 Below is my modified code; maybe you can try it.
In my case, I removed the Wan safetensors path from model_configs and pass it in separately as a new wan_paths: list[str] argument:
@staticmethod
def from_pretrained(
    wan_paths,
    torch_dtype: torch.dtype = torch.bfloat16,
    device: Union[str, torch.device] = "cuda",
    model_configs: list[ModelConfig] = [],
    # tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
    # audio_processor_config: ModelConfig = None,
    redirect_common_files: bool = True,
    use_usp=False,
):
    # Redirect model path
    if redirect_common_files:
        redirect_dict = {
            "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
            "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
            "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
        }
        for model_config in model_configs:
            if model_config.origin_file_pattern is None or model_config.model_id is None:
                continue
            if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
                print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
                model_config.model_id = redirect_dict[model_config.origin_file_pattern]
    # Initialize pipeline
    pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
    if use_usp: pipe.initialize_usp()
    # Download and load models
    model_manager = ModelManager()
    for model_config in model_configs:
        model_config.download_if_necessary(use_usp=use_usp)
        model_manager.load_model(
            model_config.path,
            # device=model_config.offload_device or device,
            device='cpu',
            torch_dtype=model_config.offload_dtype or torch_dtype
        )
    # Load models
    # pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
    ## My change: load the WanModel directly instead of via fetch_model
    print(f"====== go load wan config ======")
    wan22_config_path = 'z_my_wan22_5B_config.json'
    with open(wan22_config_path, "r") as f:
        config = json.load(f)
    dit = WanModel(**config)
    print(f"====== go load wan weight ======")
    dit_state_dict = {}
    for each in wan_paths:
        dit_state_dict.update(load_file(each))
    missing, unexpected = dit.load_state_dict(dit_state_dict, strict=False)
    with torch.no_grad():
        miss = set(missing)
        for name, p in dit.named_parameters():
            if name in miss:
                p.zero_()
        for name, b in dit.named_buffers():
            if name in miss:
                if b.is_floating_point() or b.is_complex():
                    b.zero_()
                else:
                    b.fill_(0)
    print(f"====== load wan weight ok!!! ======")
    pipe.dit = dit
    # print(dit)
    # dit = model_manager.fetch_model("wan_video_dit", index=2)
    # if isinstance(dit, list):
    #     pipe.dit, pipe.dit2 = dit
    # else:
    #     pipe.dit = dit
    pipe.vae = model_manager.fetch_model("wan_video_vae")
    # pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
    # pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
    # pipe.vace = model_manager.fetch_model("wan_video_vace")
    # pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
    # Size division factor
    if pipe.vae is not None:
        pipe.height_division_factor = pipe.vae.upsampling_factor * 2
        pipe.width_division_factor = pipe.vae.upsampling_factor * 2
    # Initialize tokenizer
    # tokenizer_config.download_if_necessary(use_usp=use_usp)
    # pipe.prompter.fetch_models(pipe.text_encoder)
    # pipe.prompter.fetch_tokenizer(tokenizer_config.path)
    # if audio_processor_config is not None:
    #     audio_processor_config.download_if_necessary(use_usp=use_usp)
    #     from transformers import Wav2Vec2Processor
    #     pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)
    # Unified Sequence Parallel
    if use_usp: pipe.enable_usp()
    return pipe
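For anyone else trying this, the call site would look roughly like the following (the shard paths and the ModelConfig entry are illustrative placeholders, and WanVideoPipeline and ModelConfig are assumed to be imported as in the code above):

    import torch

    # Placeholder paths to the Wan 2.2 5B DiT safetensors shards.
    wan_paths = [
        "models/Wan2.2-TI2V-5B/diffusion_pytorch_model-00001-of-00003.safetensors",
        "models/Wan2.2-TI2V-5B/diffusion_pytorch_model-00002-of-00003.safetensors",
        "models/Wan2.2-TI2V-5B/diffusion_pytorch_model-00003-of-00003.safetensors",
    ]

    pipe = WanVideoPipeline.from_pretrained(
        wan_paths,
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            # Only the DiT weights were moved out into wan_paths; anything still
            # fetched through ModelManager (e.g. the VAE) keeps its ModelConfig entry.
            ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth"),
        ],
    )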
Thanks for sharing! I’ll give it a try. @wyyfffff