variational-diffusion-models

[feat] (vdm.py): add loss parametrization to image prediction

Status: Open · relyativist opened this issue 9 months ago · 0 comments

According to doi.org/10.48550/arXiv.2401.06281 (Section 4.2, Equation 4.62), the image-prediction loss can be parametrized by multiplying `pred_loss` by the signal-to-noise ratio term `snr_t`. Thus, for image prediction in `VDM.forward()`, we change:

```diff
- pred_loss = ((model_out - noise) ** 2).sum((1, 2, 3))
+ snr_t = torch.exp(-gamma_t)  # SNR(t) = exp(-gamma_t); assumes gamma_t has shape (B,)
+ pred_loss = snr_t * ((model_out - x) ** 2).sum((1, 2, 3))
```
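A quick way to sanity-check the change, assuming the VDM convention $\gamma_t = -\log \mathrm{SNR}(t)$ with $\alpha_t^2 = \sigma(-\gamma_t)$ and $\sigma_t^2 = \sigma(\gamma_t)$: the noise implied by an image prediction $\hat{x}$ is $\hat{\epsilon} = (z_t - \alpha_t \hat{x}) / \sigma_t$, so $\hat{\epsilon} - \epsilon = (\alpha_t / \sigma_t)(x - \hat{x})$ and $\lVert \hat{\epsilon} - \epsilon \rVert^2 = \mathrm{SNR}(t) \lVert x - \hat{x} \rVert^2$. The standalone sketch below checks this identity numerically for a single sample. It uses plain Python rather than torch, and `sigmoid` and `check_snr_weighting` are hypothetical helpers, not part of `vdm.py`.

```python
import math
import random


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def check_snr_weighting(gamma_t: float, dim: int = 16, seed: int = 0) -> tuple[float, float]:
    """Return (eps-prediction loss, SNR-weighted x-prediction loss) for one sample.

    Assumes the VDM convention gamma_t = -log SNR(t) with a variance-preserving
    schedule: alpha_t^2 = sigmoid(-gamma_t), sigma_t^2 = sigmoid(gamma_t).
    """
    rng = random.Random(seed)
    alpha_t = math.sqrt(sigmoid(-gamma_t))
    sigma_t = math.sqrt(sigmoid(gamma_t))
    snr_t = math.exp(-gamma_t)  # equals alpha_t**2 / sigma_t**2

    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]      # clean data
    noise = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # eps
    z_t = [alpha_t * xi + sigma_t * ei for xi, ei in zip(x, noise)]

    # Hypothetical, deliberately imperfect image prediction x_hat
    x_hat = [xi + rng.gauss(0.0, 0.1) for xi in x]
    # Noise prediction implied by x_hat: eps_hat = (z_t - alpha_t * x_hat) / sigma_t
    eps_hat = [(zi - alpha_t * xh) / sigma_t for zi, xh in zip(z_t, x_hat)]

    eps_loss = sum((eh - ei) ** 2 for eh, ei in zip(eps_hat, noise))
    x_loss = snr_t * sum((xh - xi) ** 2 for xh, xi in zip(x_hat, x))
    return eps_loss, x_loss


for gamma in (-5.0, 0.0, 5.0):
    eps_loss, x_loss = check_snr_weighting(gamma)
    print(f"gamma={gamma:+.1f}  eps-loss={eps_loss:.6f}  snr-weighted x-loss={x_loss:.6f}")
```

If the two columns agree for every `gamma`, the SNR-weighted `pred_loss` is, under these assumptions, a drop-in replacement for the noise-prediction `pred_loss`, so any weighting that `VDM.forward()` applies to `pred_loss` afterwards should not need to change.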

relyativist · May 27 '25 12:05