VAR
Question about function idxBl_to_var_input
# ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
    next_scales = []
    B = gt_ms_idx_Bl[0].shape[0]
    C = self.Cvae
    H = W = self.v_patch_nums[-1]
    SN = len(self.v_patch_nums)
    f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
    pn_next: int = self.v_patch_nums[0]
    for si in range(SN-1):
        if self.prog_si == 0 or (0 <= self.prog_si-1 < si): break    # progressive training: not supported yet, prog_si always -1
        h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next), size=(H, W), mode='bicubic')
        f_hat.add_(self.quant_resi[si/(SN-1)](h_BChw))
        pn_next = self.v_patch_nums[si+1]
        next_scales.append(F.interpolate(f_hat, size=(pn_next, pn_next), mode='area').view(B, C, -1).transpose(1, 2))
    return torch.cat(next_scales, dim=1) if len(next_scales) else None    # cat BlCs to BLC, this should be float32
Thanks for the brilliant work! Here is a question: why is next_scales built from F.interpolate(f_hat, size=(pn_next, pn_next), mode='area') rather than directly from self.embedding(gt_ms_idx_Bl[si])? f_hat is initialized to zeros and then only accumulated by addition.
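For reference, here is a minimal standalone sketch of the accumulation pattern the loop above implements, using toy shapes and a plain addition in place of the repo's quant_resi module (so the quant_resi refinement is an assumption simplified away here). It shows that the teacher-forcing input appended for the next scale is the downsampled running reconstruction f_hat, i.e. the sum of all previous scales' upsampled embeddings, not the raw embedding of the current scale alone.

import torch
import torch.nn.functional as F

B, C = 2, 4                      # toy batch size and Cvae channels
v_patch_nums = (1, 2, 3, 4)      # toy multi-scale patch sizes
H = W = v_patch_nums[-1]
V = 16                           # toy codebook size
embedding = torch.nn.Embedding(V, C)

# fake ground-truth token indices per scale: list of (B, pn*pn) tensors
gt_ms_idx_Bl = [torch.randint(0, V, (B, pn * pn)) for pn in v_patch_nums]

next_scales = []
f_hat = torch.zeros(B, C, H, W)          # starts at zero, then accumulates
pn_next = v_patch_nums[0]
for si in range(len(v_patch_nums) - 1):
    # embedding of scale si, upsampled to the full resolution
    h_BChw = F.interpolate(
        embedding(gt_ms_idx_Bl[si]).transpose(1, 2).reshape(B, C, pn_next, pn_next),
        size=(H, W), mode='bicubic')
    f_hat = f_hat + h_BChw               # plain add stands in for quant_resi[si/(SN-1)]
    pn_next = v_patch_nums[si + 1]
    # teacher-forcing input for the NEXT scale: the partial reconstruction so far,
    # resized to that scale, so it mixes information from scales 0..si
    next_scales.append(
        F.interpolate(f_hat, size=(pn_next, pn_next), mode='area')
        .view(B, C, -1).transpose(1, 2))

var_input = torch.cat(next_scales, dim=1)   # (B, sum of pn*pn for scales 1..SN-1, C)
print(var_input.shape)                       # torch.Size([2, 29, 4])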