DDPM-Pytorch

The reverse sampling results are not ideal

Open • 2000lf opened this issue 1 year ago • 1 comment

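```python
import torch
from tqdm import tqdm

# Forward process: noise the clean batch `im` to the last timestep
noise = torch.randn_like(im).to(device)
t = torch.full((im.shape[0],), diffusion_config['num_timesteps'] - 1, device=device)
# t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)
xt = scheduler.add_noise(im, noise, t)

# Reverse process, starting from the forward-noised xt instead of fresh random noise
for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
    # Get prediction of noise (not used below in this experiment)
    noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))

    # Use scheduler to get x0 and xt-1; the forward-process `noise` is passed
    # in place of `noise_pred` to test the reverse sampling in isolation
    xt, x0_pred = scheduler.sample_prev_timestep(xt, noise, torch.as_tensor(i).to(device))
```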

To validate the reverse sampling process, I replaced the initial random noise with the xt produced by the forward process, and replaced the model's output with the noise that was added during the forward process. However, the results are not ideal. Do you have any insights on this?

2000lf, Dec 14 '24 09:12
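The check described in the issue can be sketched in plain Python on a few scalar "pixels" instead of the repo's `scheduler` and `model`. This is a minimal sketch, assuming the standard DDPM reverse step with posterior variance and a linear beta schedule (1e-4 to 0.02 over 1000 steps), which may differ from this repo's config; `linear_schedule`, `reverse_step` and `reconstruct` are hypothetical helpers. It compares two noise inputs to the reverse step: the forward-process `noise` reused at every step, as in the snippet, and an "oracle" noise recomputed from the current xt and the known x0. Once the sampler has added fresh noise, only the oracle still equals the noise actually contained in xt, so it is the one expected to reproduce x0; reusing the original `noise` may account for part of the poor result.

```python
import math
import random


def linear_schedule(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Assumed linear beta schedule; returns betas, alphas and cumulative alpha_bars."""
    betas = [beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
             for i in range(num_timesteps)]
    alphas = [1.0 - b for b in betas]
    alpha_bars, prod = [], 1.0
    for a in alphas:
        prod *= a
        alpha_bars.append(prod)
    return betas, alphas, alpha_bars


def reverse_step(xt, eps, t, betas, alphas, alpha_bars, rng):
    """One standard DDPM step x_t -> x_{t-1} given a noise estimate eps."""
    coef = betas[t] / math.sqrt(1.0 - alpha_bars[t])
    mean = [(x - coef * e) / math.sqrt(alphas[t]) for x, e in zip(xt, eps)]
    if t == 0:
        return mean  # no noise is added at the final step
    # Posterior variance beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
    var = betas[t] * (1.0 - alpha_bars[t - 1]) / (1.0 - alpha_bars[t])
    return [m + math.sqrt(var) * rng.gauss(0.0, 1.0) for m in mean]


def reconstruct(x0, oracle, num_timesteps=1000, seed=0):
    """Noise x0 to the last timestep, then run the reverse process back to t = 0."""
    rng = random.Random(seed)
    betas, alphas, alpha_bars = linear_schedule(num_timesteps)
    last = num_timesteps - 1
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    ab = alpha_bars[last]
    xt = [math.sqrt(ab) * a + math.sqrt(1.0 - ab) * n for a, n in zip(x0, noise)]
    for t in reversed(range(num_timesteps)):
        if oracle:
            # Noise actually contained in the current xt, given the known x0
            ab = alpha_bars[t]
            eps = [(x - math.sqrt(ab) * a) / math.sqrt(1.0 - ab) for x, a in zip(xt, x0)]
        else:
            # Forward-process noise reused at every step, as in the issue
            eps = noise
        xt = reverse_step(xt, eps, t, betas, alphas, alpha_bars, rng)
    return xt


x0 = [-0.8, 0.0, 0.5, 1.0]
print(reconstruct(x0, oracle=True))   # matches x0 up to floating-point error
print(reconstruct(x0, oracle=False))  # compare against x0
```

With `oracle=True` the final step returns x0 up to rounding, because the noise estimate stays consistent with xt at every step; with the reused forward noise it generally does not.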

Hey, I implemented this code as well, and when I sample xt-1 over the 1000 trained timesteps, I get an image whose pixel intensity range keeps increasing, to the point that my network can no longer denoise it because it is so far outside its learned distribution (~N(0,1)). I think the issue is in the mean computation in sample_prev_timestep: xt may be divided by a small value on each call, and since xt-1 becomes xt for the next step, the problem gets worse with every denoising step. I don't know if that was your issue. Good luck!

Katrinex6, Aug 27 '25 15:08
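The division hypothesis in the comment above can be sized up with a little arithmetic. Assuming `sample_prev_timestep` follows the standard DDPM mean `(xt - beta_t / sqrt(1 - alpha_bar_t) * noise_pred) / sqrt(alpha_t)` (not confirmed against this repo) and an assumed linear schedule (1e-4 to 0.02 over 1000 steps), xt is divided by `sqrt(alpha_t)`, which is slightly below 1, at every step. This sketch, with hypothetical helpers `linear_alphas` and `reverse_mean_scaling`, computes the per-step and cumulative factors; the product telescopes to `1 / sqrt(alpha_bar)` at the last timestep.

```python
import math


def linear_alphas(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Assumed linear beta schedule; returns alphas and cumulative alpha_bars."""
    betas = [beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
             for i in range(num_timesteps)]
    alphas = [1.0 - b for b in betas]
    alpha_bars, prod = [], 1.0
    for a in alphas:
        prod *= a
        alpha_bars.append(prod)
    return alphas, alpha_bars


def reverse_mean_scaling(num_timesteps=1000):
    """Per-step and cumulative factors that the 1 / sqrt(alpha_t) term applies to xt."""
    alphas, alpha_bars = linear_alphas(num_timesteps)
    per_step = [1.0 / math.sqrt(a) for a in alphas]
    cumulative = 1.0
    for f in per_step:
        cumulative *= f
    # The product telescopes to 1 / sqrt(alpha_bar) at the last timestep
    telescoped = 1.0 / math.sqrt(alpha_bars[-1])
    return max(per_step), cumulative, telescoped


largest_step, cumulative, telescoped = reverse_mean_scaling()
print(f"largest single-step factor: {largest_step:.4f}")        # about 1.0102
print(f"cumulative factor over 1000 steps: {cumulative:.1f}")   # about 157
```

Each single division barely matters, but over a full pass the factors multiply to roughly 157, so any part of xt that the subtracted noise term fails to cancel can be amplified by up to about two orders of magnitude, which would be consistent with a growing intensity range.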