DiffSynth-Studio

Inquiry about details of Wan2.2 I2V LoRA training

Open · shinxg opened this issue 5 months ago · 2 comments

Hi, DiffSynth-Studio team. I have recently been experimenting with Wan2.2 I2V training and noticed that Wan2.2 newly adopts token-replacement-style I2V training. The input_latents are first initialized with random noise, and their first frame is then replaced by the input image's latent here: https://github.com/modelscope/DiffSynth-Studio/blob/fa36739f01bac495eaabeb7d1df27f69b4f5a0d9/diffsynth/pipelines/wan_video_new.py#L658
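For readers unfamiliar with token replacement, the initialization described above can be sketched as follows. This is a minimal NumPy illustration, not DiffSynth-Studio's code: it assumes latents laid out as (batch, channels, frames, height, width), and the function name init_token_replace_latents is hypothetical.

```python
import numpy as np


def init_token_replace_latents(noise: np.ndarray, image_latent: np.ndarray) -> np.ndarray:
    """Start from pure noise and overwrite the first latent frame with the image latent.

    noise:        (B, C, T, H, W) random noise for the whole clip
    image_latent: (B, C, 1, H, W) VAE latent of the conditioning image
    """
    latents = noise.copy()               # leave the caller's noise array untouched
    latents[:, :, 0:1] = image_latent    # token replacement along the temporal axis
    return latents


rng = np.random.default_rng(0)
noise = rng.standard_normal((1, 4, 5, 8, 8))
image_latent = rng.standard_normal((1, 4, 1, 8, 8))
latents = init_token_replace_latents(noise, image_latent)
print(np.array_equal(latents[:, :, 0:1], image_latent))   # True
print(np.array_equal(latents[:, :, 1:], noise[:, :, 1:]))  # True
```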

Later, however, when the training loss is computed, the first frame is replaced with the noised version of the target video's latent here: https://github.com/modelscope/DiffSynth-Studio/blob/fa36739f01bac495eaabeb7d1df27f69b4f5a0d9/diffsynth/pipelines/wan_video_new.py#L81
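To see why this matters, the sketch below assumes the rectified-flow interpolation x_t = (1 - sigma) * x_0 + sigma * noise, the form flow-matching schedulers commonly use; whether DiffSynth-Studio's add_noise matches it term for term is an assumption here. Under it, any sigma > 0 means the first frame the model sees during training is no longer the clean latent.

```python
import numpy as np


def add_noise(x0: np.ndarray, noise: np.ndarray, sigma: float) -> np.ndarray:
    """Rectified-flow interpolation between clean latents and noise (assumed form)."""
    return (1.0 - sigma) * x0 + sigma * noise


rng = np.random.default_rng(0)
video_latents = rng.standard_normal((1, 4, 5, 8, 8))   # clean target video latents
noise = rng.standard_normal(video_latents.shape)

noised = add_noise(video_latents, noise, sigma=0.7)
# Frame 0 is noised together with every other frame, so it drifts away from the clean latent.
first_frame_gap = np.abs(noised[:, :, 0] - video_latents[:, :, 0]).mean()
print(first_frame_gap > 0)  # True
```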

This seems to differ from the original Wan2.2 inference pipeline, where the first frame is always replaced with the clean latent of the input image: https://github.com/modelscope/DiffSynth-Studio/blob/fa36739f01bac495eaabeb7d1df27f69b4f5a0d9/diffsynth/pipelines/wan_video_new.py#L461C17-L461C36
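The inference behavior described here, resetting the first frame to the clean image latent after every denoising step, can be sketched with a plain Euler sampler. The sampler, the linear sigma schedule, and the velocity_fn callback are hypothetical stand-ins for the pipeline's scheduler and DiT; only the re-injection step mirrors what the issue describes.

```python
from typing import Callable

import numpy as np


def sample_i2v(
    noise: np.ndarray,
    image_latent: np.ndarray,
    velocity_fn: Callable[[np.ndarray, float], np.ndarray],
    num_steps: int = 10,
) -> np.ndarray:
    """Euler sampler from sigma=1 down to sigma=0 with first-frame re-injection."""
    sigmas = np.linspace(1.0, 0.0, num_steps + 1)
    latents = noise.copy()
    latents[:, :, 0:1] = image_latent            # token replacement at initialization
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        v = velocity_fn(latents, sigma)          # predicted velocity (noise - x0)
        latents = latents + (sigma_next - sigma) * v
        latents[:, :, 0:1] = image_latent        # put the clean image latent back every step
    return latents


rng = np.random.default_rng(0)
noise = rng.standard_normal((1, 4, 5, 8, 8))
image_latent = rng.standard_normal((1, 4, 1, 8, 8))
out = sample_i2v(noise, image_latent, velocity_fn=lambda x, s: np.ones_like(x))
print(np.array_equal(out[:, :, 0:1], image_latent))  # True
```

Because the first frame is overwritten after each update, the model only ever sees a clean first frame at inference time, which is the condition the issue argues training should match.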

I also checked HunyuanI2V's token-replacement setting and found what seems to be a difference between HunyuanI2V and Wan2.2 I2V. In HunyuanI2V, the training and inference pipelines are the same: the first frame's latent is always the clean latent of the input image (https://github.com/Tencent-Hunyuan/HunyuanVideo-I2V/blob/1481c1d5ae88e9905f54f2a3c6a1b68ef2a10528/hyvideo/diffusion/flow/transport.py#L189), and the loss is computed only on frames other than the first.
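Restricting the loss to frames after the first amounts to slicing the temporal axis before the reduction. A minimal NumPy sketch, assuming a (B, C, T, H, W) layout; the name i2v_mse_loss is hypothetical:

```python
import numpy as np


def i2v_mse_loss(pred: np.ndarray, target: np.ndarray, skip_first_frame: bool = True) -> float:
    """Mean squared error over the latent video, optionally ignoring frame 0.

    Frame 0 is the conditioning image, so there is nothing for the model to predict there.
    """
    if skip_first_frame:
        pred, target = pred[:, :, 1:], target[:, :, 1:]
    return float(np.mean((pred - target) ** 2))


rng = np.random.default_rng(0)
pred = rng.standard_normal((1, 4, 5, 8, 8))
target = rng.standard_normal((1, 4, 5, 8, 8))

perturbed = pred.copy()
perturbed[:, :, 0] += 100.0   # a wildly wrong first-frame prediction
print(i2v_mse_loss(pred, target) == i2v_mse_loss(perturbed, target))  # True
```

Slicing before averaging takes the mean over the remaining T - 1 frames, so the loss scale does not shrink the way it would if the first-frame error were zeroed and the mean still taken over all T frames.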

So I am wondering: is this discrepancy between the training and inference code an intentional design choice in Wan2.2, or might there be a mistake in the training code?

Looking forward to your reply. Thanks!

shinxg · Aug 14 '25 02:08

Same issue here. Have you figured it out?

J-Oyasumi · Nov 10 '25 20:11

I modified the training loss function in https://github.com/modelscope/DiffSynth-Studio/blob/8332ecebb70664606e7ff9c4c11feee276744b14/diffsynth/pipelines/wan_video_new.py#L113 as follows:

import torch


# Replacement for the pipeline's training_loss method
def training_loss(self, **inputs):
    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
    timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)

    # Keep the first-frame latent so it can be excluded both from noising and from the loss
    fuse_vae_embedding = inputs.get("fuse_vae_embedding_in_latents", False)
    first_frame_latents = None
    if fuse_vae_embedding and "first_frame_latents" in inputs:
        first_frame_latents = inputs["first_frame_latents"]

    # Add noise
    inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
    # For I2V, the first frame should not be noised; restore its original (clean) value
    if fuse_vae_embedding and first_frame_latents is not None:
        inputs["latents"][:, :, 0:1] = first_frame_latents

    training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)

    noise_pred = self.model_fn(**inputs, timestep=timestep)

    # Exclude the first frame from the loss (it is the condition, so there is nothing to predict)
    if fuse_vae_embedding and first_frame_latents is not None:
        # Compute the loss only on frames after the first
        loss = torch.nn.functional.mse_loss(
            noise_pred[:, :, 1:].float(),
            training_target[:, :, 1:].float()
        )
    else:
        # Otherwise, compute the loss over all frames as usual
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    loss = loss * self.scheduler.training_weight(timestep)
    return loss

Is my code right?
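One way to check the proposed change independently of the full pipeline is to port its I2V branch to a few lines of NumPy and test the two properties it is meant to guarantee: the model input has a clean first frame, and the loss ignores whatever the model predicts for that frame. The scheduler here is an assumed rectified-flow form (x_t = (1 - sigma) * x_0 + sigma * noise, target noise - x_0), and the model callbacks are hypothetical stand-ins for the DiT; this exercises the slicing logic only, not DiffSynth-Studio's real scheduler or model.

```python
from typing import Callable, Tuple

import numpy as np


def proposed_training_loss(
    input_latents: np.ndarray,
    first_frame_latents: np.ndarray,
    noise: np.ndarray,
    sigma: float,
    model_fn: Callable[[np.ndarray, float], np.ndarray],
) -> Tuple[float, np.ndarray]:
    """NumPy port of the modified training_loss (I2V branch only, assumed scheduler)."""
    latents = (1.0 - sigma) * input_latents + sigma * noise   # assumed add_noise
    latents[:, :, 0:1] = first_frame_latents                  # keep the condition frame clean
    target = noise - input_latents                             # assumed training_target
    pred = model_fn(latents, sigma)
    loss = float(np.mean((pred[:, :, 1:] - target[:, :, 1:]) ** 2))
    return loss, latents


rng = np.random.default_rng(0)
x0 = rng.standard_normal((1, 4, 5, 8, 8))
noise = rng.standard_normal(x0.shape)
first = x0[:, :, 0:1].copy()
seen = {}


def recording_model(latents, sigma):
    seen["input"] = latents.copy()   # capture what the hypothetical model receives
    return noise - x0                # oracle velocity for every frame


def wrong_first_frame_model(latents, sigma):
    pred = noise - x0
    pred[:, :, 0] += 100.0           # wrong only on the conditioning frame
    return pred


loss, _ = proposed_training_loss(x0, first, noise, 0.7, recording_model)
loss_wrong, _ = proposed_training_loss(x0, first, noise, 0.7, wrong_first_frame_model)
print(loss, loss_wrong)                                  # 0.0 0.0
print(np.array_equal(seen["input"][:, :, 0:1], first))  # True
```

Passing these checks only shows that the slicing and re-injection behave as the comments say; whether training this way matches the recipe intended for Wan2.2 is the open question in this thread.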

fuyuchenIfyw · Nov 17 '25 05:11