InstantMesh
CFG in Zero123++ fine-tuning
I have some questions about the classifier-free guidance in the Zero123++ fine-tuning code. In model.py:
# classifier-free guidance
if np.random.rand() < self.drop_cond_prob:
    prompt_embeds = self.pipeline._encode_prompt([""] * B, self.device, 1, False)
    cond_latents = self.encode_condition_image(torch.zeros_like(cond_imgs))
else:
    prompt_embeds = self.forward_vision_encoder(cond_imgs)
    cond_latents = self.encode_condition_image(cond_imgs)

latents = self.encode_target_images(target_imgs)
noise = torch.randn_like(latents)
latents_noisy = self.train_scheduler.add_noise(latents, noise, t)

v_pred = self.forward_unet(latents_noisy, t, prompt_embeds, cond_latents)
v_target = self.get_v(latents, noise, t)

loss, loss_dict = self.compute_loss(v_pred, v_target)
This does not seem to implement the CFG weighted sum, i.e. something like
v_pred = v_pred_uncond + guidance_scale * (v_pred_cond - v_pred_uncond)?
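To make it concrete, what I expected to find somewhere is the usual two-pass blend, sketched below with made-up names (a hypothetical sketch, not code from this repo):

import torch

# Hypothetical sketch of the blending I expected (names are made up, not
# from the repo): run the denoiser with and without the condition, then
# take the guidance-weighted sum of the two v-predictions.
def cfg_v_pred(unet, latents_noisy, t, cond_embeds, uncond_embeds, guidance_scale):
    v_cond = unet(latents_noisy, t, encoder_hidden_states=cond_embeds).sample
    v_uncond = unet(latents_noisy, t, encoder_hidden_states=uncond_embeds).sample
    return v_uncond + guidance_scale * (v_cond - v_uncond)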
In pipeline.py:
def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
    if is_cfg_guidance:
        encoder_hidden_states = encoder_hidden_states[1:]
        class_labels = class_labels[1:]
    self.unet(
        noisy_cond_lat, timestep,
        encoder_hidden_states=encoder_hidden_states,
        class_labels=class_labels,
        cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
        **kwargs
    )
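As far as I can tell, the [1:] slicing assumes the batch layout that diffusers produces under CFG, where the unconditional embedding sits at index 0 along the batch dimension and the conditional one follows it. A small sketch with made-up shapes:

import torch

# Assumed CFG batch layout (shapes are made up): diffusers' _encode_prompt
# concatenates [uncond_embeds, cond_embeds] along dim 0, so [1:] keeps only
# the conditional rows for the reference (write) pass.
uncond_embeds = torch.zeros(1, 77, 1024)
cond_embeds = torch.randn(1, 77, 1024)
encoder_hidden_states = torch.cat([uncond_embeds, cond_embeds])  # (2, 77, 1024)
cond_only = encoder_hidden_states[1:]                            # (1, 77, 1024)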
def forward(
    self, sample, timestep, encoder_hidden_states, class_labels=None,
    *args, cross_attention_kwargs,
    down_block_res_samples=None, mid_block_res_sample=None,
    **kwargs
):
    cond_lat = cross_attention_kwargs['cond_lat']
    is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
    print("is_cfg_guidance:", is_cfg_guidance)
    noise = torch.randn_like(cond_lat)
    if self.training:
        noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
        noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
    else:
        noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
        noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
    ref_dict = {}
    self.forward_cond(
        noisy_cond_lat, timestep,
        encoder_hidden_states, class_labels,
        ref_dict, is_cfg_guidance, **kwargs
    )
    weight_dtype = self.unet.dtype
    return self.unet(
        sample, timestep,
        encoder_hidden_states, *args,
        class_labels=class_labels,
        cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
        down_block_additional_residuals=[
            sample.to(dtype=weight_dtype) for sample in down_block_res_samples
        ] if down_block_res_samples is not None else None,
        mid_block_additional_residual=(
            mid_block_res_sample.to(dtype=weight_dtype)
            if mid_block_res_sample is not None else None
        ),
        **kwargs
    )
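For context, my understanding of the mode="w" / mode="r" flags is that they implement reference-only attention: the conditioning pass writes its attention context into ref_dict, and the main pass reads it back and appends it. A simplified, hypothetical sketch of that mechanism (the real logic lives in the repo's custom attention processor, driven by cross_attention_kwargs):

import torch

# Simplified, hypothetical sketch of write/read reference attention; not
# the repo's actual attention-processor code.
def reference_attention_context(name, encoder_hidden_states, mode, ref_dict):
    if mode == "w":
        # conditioning pass: stash this layer's context for later reuse
        ref_dict[name] = encoder_hidden_states
    elif mode == "r":
        # main pass: append the stashed reference context before attention
        encoder_hidden_states = torch.cat(
            [encoder_hidden_states, ref_dict.pop(name)], dim=1
        )
    return encoder_hidden_states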
It seems that is_cfg_guidance is never actually set in cross_attention_kwargs, so it always defaults to False.
I would greatly appreciate it if you could explain how CFG is implemented in this code. Thank you!
I think this is handled internally by diffusers, since the pipeline uses diffusers.StableDiffusionPipeline as its superclass.
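During training the code only has to drop the condition with probability drop_cond_prob so the model learns both the conditional and unconditional predictions; the weighted sum is applied at inference time, inside the denoising loop of StableDiffusionPipeline.__call__. A paraphrased sketch of that loop (not the exact diffusers source):

import torch

# Paraphrased sketch of the CFG step inside StableDiffusionPipeline's
# denoising loop (not the exact diffusers source). prompt_embeds holds
# [uncond, cond] stacked along the batch dimension, as returned by
# _encode_prompt with do_classifier_free_guidance=True.
def denoise_with_cfg(unet, scheduler, latents, prompt_embeds, guidance_scale):
    for t in scheduler.timesteps:
        # run uncond and cond through the UNet in one batched forward pass
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
        # classifier-free guidance: the weighted sum the question asks about
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents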