
The training problem

Open f1yfisher opened this issue 1 year ago • 10 comments

I found that STDiT-v3 has 8 output channels, but in the training phase only the first 4 channels are used in the loss function.

f1yfisher avatar Jun 25 '24 06:06 f1yfisher

Thanks. Could you help me pinpoint where you identified this issue? I looked but failed to locate where the channel selection happens.

JThh avatar Jul 02 '24 23:07 JThh

Open-Sora v1.2 training uses the rectified flow scheduler. On lines 102-107 there, only the first 4 channels are used in the loss function.

f1yfisher avatar Jul 03 '24 03:07 f1yfisher
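
For concreteness, here is a minimal sketch of a rectified-flow loss that drops the second half of the channels, as described above. All names are hypothetical (this is not the actual Open-Sora code), and the sign convention of the velocity target varies between implementations:

```python
import torch
import torch.nn.functional as F

def rf_training_loss(model_out, x0, noise):
    """Rectified-flow MSE loss that uses only the first half of the channels.

    model_out: (B, 8, T, H, W) -- 4 velocity channels + 4 unused "sigma" channels
    x0, noise: (B, 4, T, H, W) -- clean latent and sampled Gaussian noise
    """
    # Keep only the first 4 channels; the second half never receives
    # a gradient from this loss.
    v_pred, _sigma = model_out.chunk(2, dim=1)
    # Rectified-flow target: the straight-line velocity from noise to data.
    v_target = x0 - noise
    return F.mse_loss(v_pred, v_target)
```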

In the DiT training code, the purpose of predicting sigma is to learn both the mean and the variance of the noise, and then compute a KL loss against the ground truth for the variance. The code sets pred_sigma to True by default, but directly uses only the mean (the first 4 channels) as the predicted noise.

narrowsnap avatar Jul 09 '24 12:07 narrowsnap
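
The pred_sigma behavior described above can be sketched as a toy output head. This is a hypothetical minimal module, not the actual STDiT-v3 code:

```python
import torch
import torch.nn as nn

class ToyHead(nn.Module):
    """With pred_sigma=True the final projection emits 2 * in_channels:
    4 "mean" channels followed by 4 "variance" channels."""
    def __init__(self, hidden, in_channels=4, pred_sigma=True):
        super().__init__()
        self.pred_sigma = pred_sigma
        out_channels = in_channels * 2 if pred_sigma else in_channels
        self.proj = nn.Linear(hidden, out_channels)

    def forward(self, h):
        x = self.proj(h)  # (..., 8) when pred_sigma is True
        if self.pred_sigma:
            mean, _sigma = x.chunk(2, dim=-1)
            return mean  # the variance half is produced but discarded
        return x
```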

The released model also has 8 output channels (mean and variance). Why set the model to predict both mean and variance, but never compute a loss on the variance?

f1yfisher avatar Jul 09 '24 12:07 f1yfisher

> The released model also has 8 output channels (mean and variance). Why set the model to predict both mean and variance, but never compute a loss on the variance?

I also want to know. @JThh Is there any reason to do this?

narrowsnap avatar Jul 11 '24 11:07 narrowsnap

This issue is stale because it has been open for 7 days with no activity.

github-actions[bot] avatar Jul 19 '24 01:07 github-actions[bot]

> In the DiT training code, the purpose of predicting sigma is to learn both the mean and the variance of the noise, and then compute a KL loss against the ground truth for the variance. The code sets pred_sigma to True by default, but directly uses only the mean (the first 4 channels) as the predicted noise.

So why? What are the output dimensions, and what is the input?

henbucuoshanghai avatar Jul 19 '24 05:07 henbucuoshanghai

I had the same doubt. Why use out_channels=8 but take only the mean (the first 4 channels), dropping the rest? Maybe latent_z is simply updated by the predicted velocity, without those extra channels:

      # Euler step of the rectified-flow ODE sampler:
      dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i]
      dt = dt / self.num_timesteps  # normalize the step to [0, 1]
      z = z + v_pred * dt[:, None, None, None, None]  # broadcast dt over (C, T, H, W)

xesdiny avatar Jul 27 '24 11:07 xesdiny

> I had the same doubt. Why use out_channels=8 but take only the mean (the first 4 channels), dropping the rest? Maybe latent_z is simply updated by the predicted velocity, without those extra channels:
>
>       dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i]
>       dt = dt / self.num_timesteps
>       z = z + v_pred * dt[:, None, None, None, None]

This design comes from PixArt-alpha, which aligns with the original DiT. (https://github.com/PixArt-alpha/PixArt-sigma/issues/81#issuecomment-2100610843)

narrowsnap avatar Aug 07 '24 07:08 narrowsnap
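
For comparison, the original DiT/ADM recipe does train the extra variance channels, via a variational-bound term in which the predicted mean is detached so that only the variance channels learn from it. A simplified sketch of that hybrid objective, where vb_term is a hypothetical stand-in for the omitted KL computation:

```python
import torch
import torch.nn.functional as F

def hybrid_loss(model_out, eps, vb_term):
    """Sketch of the DiT/ADM hybrid objective (simplified).

    The predicted mean is detached before the variational-bound term,
    so the mean is trained purely by MSE and only the variance channels
    receive gradients from vb_term.
    """
    eps_pred, var_pred = model_out.chunk(2, dim=1)
    mse = F.mse_loss(eps_pred, eps)
    # Detach the mean: the VB term must not perturb the MSE-trained mean.
    vb = vb_term(eps_pred.detach(), var_pred)
    return mse + vb
```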

I see.

henbucuoshanghai avatar Aug 07 '24 08:08 henbucuoshanghai

This issue is stale because it has been open for 7 days with no activity.

github-actions[bot] avatar Sep 17 '24 01:09 github-actions[bot]

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot] avatar Sep 25 '24 01:09 github-actions[bot]