DongZhuoBai
> I see that the loss used for distillation is:
> ```python
> def direct_distill_loss(self, **inputs):
>     self.scheduler.set_timesteps(inputs["num_inference_steps"])
>     models = {name: getattr(self, name) for name in self.in_iteration_models}
>     for progress_id, timestep in enumerate(self.scheduler.timesteps):
>         timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
>         noise_pred = self.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
>         ...
> ```
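Hi, I've also been trying to train the qwen-image distill code recently. Could you share how you got the LoRA training code running? On my side I can't get multi-GPU parallel training to work, so I keep running out of GPU memory. Thanks!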