
Some questions about training the qwen-image distill model

Open 1343744768 opened this issue 5 months ago • 5 comments

Hi, I have a few questions from reading the source code:

When training the distilled model, is CFG distilled away directly via SFT (supervised fine-tuning)?

Is there no additional distillation constraint, e.g. something like DMD (Distribution Matching Distillation)? I noticed that in the code, neither the distilled-model nor the non-distilled-model training pipeline uses inputs_nega after data preprocessing.

For non-distilled-model training, shouldn't the prompt be dropped with some probability, so that the model learns to generate images from inputs_nega? As it stands, inputs_nega doesn't seem to be used at all.
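For reference, the standard way a base (non-distilled) model learns the unconditional branch that CFG needs is to drop the prompt with some probability during training. Below is only a minimal sketch of that trick, not code from this repo: `inputs_posi`/`inputs_nega` mirror the preprocessing outputs discussed above, and `drop_prob` is an assumed hyperparameter.

```python
import random

def maybe_drop_prompt(inputs_posi, inputs_nega, drop_prob=0.1):
    # CFG training trick (hypothetical sketch): with probability drop_prob,
    # substitute the empty/negative prompt inputs so the model also learns
    # the unconditional distribution that CFG relies on at inference time.
    return inputs_nega if random.random() < drop_prob else inputs_posi
```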

1343744768 avatar Aug 27 '25 12:08 1343744768

I noticed the distillation loss is:

```python
def direct_distill_loss(self, **inputs):
    # Run the full few-step sampling schedule under autograd.
    self.scheduler.set_timesteps(inputs["num_inference_steps"])
    models = {name: getattr(self, name) for name in self.in_iteration_models}
    for progress_id, timestep in enumerate(self.scheduler.timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
        noise_pred = self.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
        inputs["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
    # MSE between the final denoised latents and the ground-truth image latents.
    loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
    return loss
```

This looks very blunt: it simply runs the prescribed number of diffusion steps and takes an MSE loss against the image latents. In my experiments the loss fluctuates a lot and isn't very stable. After training, few-step generation improves somewhat, but the results still can't be used directly. I've rarely seen such a direct distillation method before. Is the only real downside that it consumes a lot of GPU memory?
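On the memory question: backpropagating through every sampling step keeps one set of activations alive per step, so memory grows with `num_inference_steps`. One common workaround is to cut gradients at every step except the last, so only a single forward pass is differentiated. The following is only a minimal sketch under that assumption, reusing the names from the snippet above (`model_fn`, `step`, `scheduler`); it is not code from this repo, and truncating gradients this way also changes the training signal:

```python
import torch

def truncated_distill_loss(self, **inputs):
    # Hypothetical variant of direct_distill_loss: same sampling loop, but
    # autograd is enabled only on the final step, so the backward pass holds
    # activations for one model forward instead of all of them.
    self.scheduler.set_timesteps(inputs["num_inference_steps"])
    models = {name: getattr(self, name) for name in self.in_iteration_models}
    num_steps = len(self.scheduler.timesteps)
    for progress_id, timestep in enumerate(self.scheduler.timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
        with torch.set_grad_enabled(progress_id == num_steps - 1):
            noise_pred = self.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
            inputs["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
    loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
    return loss
```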

Jason-Chi-xx avatar Sep 19 '25 11:09 Jason-Chi-xx

> (quoting @Jason-Chi-xx's `direct_distill_loss` comment above)

Hi, I've also been trying to train the qwen-image distill code recently. May I ask how you got the LoRA training running? On my side, multi-GPU parallel training never works and I keep running out of GPU memory. Thanks!

zhuobaidong avatar Sep 22 '25 14:09 zhuobaidong

> (quoting the `direct_distill_loss` comment above)

May I ask roughly how much data you trained on? And when you say few-step generation "can't be used directly" after training, do you mean the generated images just don't look good to the human eye?

zhuobaidong avatar Sep 30 '25 09:09 zhuobaidong

> (quoting the `direct_distill_loss` discussion above)

> May I ask roughly how much data you trained on? And when you say few-step generation "can't be used directly" after training, do you mean the generated images just don't look good to the human eye?

Any follow-up on this? For reproducing the official distilled model on a single node with 8 A100s, roughly how long would training take?

min-star avatar Oct 21 '25 08:10 min-star