Paint-by-Example

question about training input

Open · D222097 opened this issue 2 years ago · 0 comments

Nice work! I'm wondering why the input can be set up this way during training.

[Example images, left to right: image_GT (img), inpaint_image (masked_img), inpaint_mask (msk), ref_imgs (ref_img)]

In this work, I found that the inputs are the GT (with noise added), masked_img, mask, and ref_img. As shown below, the input x_start to the UNet is the concatenation of z (the encoding of the GT), z_inpaint (the encoding of masked_img), and mask_resize (the downsampled mask):

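# x_start: GT latent (4 ch) + masked-image latent (4 ch) + downsampled mask (1 ch), stacked on channels
z_new = torch.cat((z, z_inpaint, mask_resize), dim=1)  # x_start

def p_losses(self, x_start, cond, t, noise=None):
    if self.first_stage_key == 'inpaint':
        # x_start=x_start[:,:4,:,:]
        # noise is sampled only for the first 4 channels (the GT latent z)
        noise = default(noise, lambda: torch.randn_like(x_start[:, :4, :, :]))
        # forward diffusion is applied to the GT latent only
        x_noisy = self.q_sample(x_start=x_start[:, :4, :, :], t=t, noise=noise)
        # the clean conditioning channels (z_inpaint, mask_resize) are appended unchanged
        x_noisy = torch.cat((x_noisy, x_start[:, 4:, :, :]), dim=1)
    ...
    model_output = self.apply_model(x_noisy, t, cond)
    ...

    if self.parameterization == "x0":
        target = x_start
    elif self.parameterization == "eps":
        # with eps-parameterization the target is the 4-channel noise, not the GT
        target = noise
    else:
        raise NotImplementedError()

    loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
    ...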
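To make the channel bookkeeping in the excerpt concrete, here is a minimal NumPy sketch of how the 9-channel x_start is assembled and how only its first 4 channels are noised. The shapes (batch 2, 4 latent channels, 64×64 latents), the fixed `alpha_bar_t=0.5`, and the helper names `build_x_start` and `q_sample_first4` are illustrative assumptions, not the repository's code; `alpha_bar_t` stands in for the cumulative noise-schedule product that `q_sample` uses at step t.

```python
import numpy as np

def build_x_start(z, z_inpaint, mask_resize):
    """Hypothetical helper: concatenate GT latent, masked-image latent and mask along channels."""
    return np.concatenate((z, z_inpaint, mask_resize), axis=1)

def q_sample_first4(x_start, alpha_bar_t, noise):
    """Hypothetical helper: noise only the GT latent (channels 0-3), keep conditioning channels clean."""
    z0 = x_start[:, :4]
    z_t = np.sqrt(alpha_bar_t) * z0 + np.sqrt(1.0 - alpha_bar_t) * noise
    return np.concatenate((z_t, x_start[:, 4:]), axis=1)

rng = np.random.default_rng(0)
z = rng.standard_normal((2, 4, 64, 64))          # stand-in for the GT latent
z_inpaint = rng.standard_normal((2, 4, 64, 64))  # stand-in for the masked-image latent
mask_resize = (rng.random((2, 1, 64, 64)) > 0.5).astype(np.float64)  # stand-in binary mask

x_start = build_x_start(z, z_inpaint, mask_resize)
noise = rng.standard_normal((2, 4, 64, 64))
x_noisy = q_sample_first4(x_start, alpha_bar_t=0.5, noise=noise)

print(x_start.shape)  # (2, 9, 64, 64)
print(x_noisy.shape)  # (2, 9, 64, 64)
```

In this sketch the conditioning channels pass through untouched; only channels 0-3 carry the noised GT, which is also the part the `eps` target (4-channel noise) refers to.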

I'm curious why the GT image can be fed into the UNet directly. Even though noise has been added to it, it is still visible to the UNet.
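How visible the GT latent stays depends on the timestep. In the standard DDPM forward process, the noised latent is sqrt(alpha_bar_t) * z0 + sqrt(1 - alpha_bar_t) * eps, so the GT's weight shrinks as t grows. The sketch below computes that weight under an assumed Stable Diffusion-style "linear" schedule (1000 steps, beta from 0.00085 to 0.012, linear in sqrt(beta)); these schedule values are assumptions, so check the repository's config before relying on them.

```python
import numpy as np

def signal_weights(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    """Return sqrt(alpha_bar_t), the weight of the clean latent z0 in z_t, for every step t.

    Assumes an LDM/Stable Diffusion-style "linear" schedule in which sqrt(beta) is linear in t.
    """
    betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_steps) ** 2
    alpha_bar = np.cumprod(1.0 - betas)
    return np.sqrt(alpha_bar)

w = signal_weights()
for t in (0, 250, 500, 750, 999):
    # GT weight sqrt(alpha_bar_t) next to the noise weight sqrt(1 - alpha_bar_t)
    print(t, round(float(w[t]), 3), round(float(np.sqrt(1.0 - w[t] ** 2)), 3))
```

Under this assumed schedule the input is almost the GT at small t, while near the last step the GT contributes only a small fraction of the amplitude and the noise dominates.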

Taking the images above as an example: during training, the input is the car image and the expected output is also the car image. At inference time, however, the input is an image unrelated to the car (an arbitrary object, or just background), and the expected output is still a car image.
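For comparison with training, the sketch below shows how the UNet input is typically assembled at inference time in latent-diffusion inpainting: the first 4 channels start as pure Gaussian noise rather than as the encoding of any image, and the masked-image latent and mask are concatenated unchanged at every denoising step. The `sample` function and the `denoise_step` callable are hypothetical placeholders for a real sampler update (DDPM, DDIM, or PLMS); this is not the repository's sampler.

```python
import numpy as np

def sample(z_inpaint, mask_resize, denoise_step, num_steps, rng):
    """Hypothetical inpainting sampling loop; denoise_step stands in for one sampler update."""
    b, _, h, w = z_inpaint.shape
    z_t = rng.standard_normal((b, 4, h, w))  # start from pure noise: no GT latent is available
    for t in reversed(range(num_steps)):
        # UNet input: current noisy latent + clean conditioning channels (9 channels total)
        unet_in = np.concatenate((z_t, z_inpaint, mask_resize), axis=1)
        z_t = denoise_step(unet_in, t)  # placeholder: returns the next 4-channel latent
    return z_t

seen_channels = []

def dummy_step(unet_in, t):
    """Dummy stand-in for a sampler update: records the input width and shrinks channels 0-3."""
    seen_channels.append(unet_in.shape[1])
    return 0.9 * unet_in[:, :4]

rng = np.random.default_rng(0)
z_inpaint = rng.standard_normal((1, 4, 64, 64))
mask_resize = np.ones((1, 1, 64, 64))
out = sample(z_inpaint, mask_resize, dummy_step, num_steps=10, rng=rng)
print(out.shape, seen_channels[:3])
```

In this sketch the 9-channel layout matches training, but the first 4 channels come from noise and the sampler's own updates rather than from an encoded GT.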

This seems a little odd. On the one hand, the model needs the GT to be optimized, and in other generative models the GT is usually used as a target rather than as a direct input to the model. On the other hand, diffusion models usually predict Gaussian noise rather than pixels, so there seems to be no other way for the model to be constrained by the GT. I don't understand how the model learns here, and I'd be grateful for any advice.
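The link between predicting noise and the GT can be written out for the standard eps-parameterization (the `eps` branch in the excerpt), assuming the usual DDPM forward process that `q_sample` implements. Here z_0 is the GT latent, [·] denotes channel concatenation, m is the downsampled mask, and c is the reference-image condition; this notation is chosen for the sketch and is not taken from the repository.

```latex
\begin{aligned}
z_t &= \sqrt{\bar\alpha_t}\, z_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,
      \qquad \epsilon \sim \mathcal{N}(0, I) \\
\mathcal{L}_{\text{simple}} &\propto \mathbb{E}_{z_0,\epsilon,t}
      \Big[\big\lVert \epsilon - \epsilon_\theta\big([z_t, z_{\text{inpaint}}, m], t, c\big) \big\rVert_2^2\Big] \\
\hat z_0 &= \frac{z_t - \sqrt{1-\bar\alpha_t}\;\epsilon_\theta\big([z_t, z_{\text{inpaint}}, m], t, c\big)}{\sqrt{\bar\alpha_t}}
\end{aligned}
```

For a given t, z_t is a fixed combination of z_0 and epsilon, so an exact noise prediction determines z_0 through the last line; the loss is proportional to the squared norm because `get_loss(..., mean=False).mean([1, 2, 3])` averages over elements instead of summing.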

D222097 · Jan 08 '24 11:01