rFID of the VAR tokenizer (VQVAE) is much higher than reported
Many people report that the rFID of the VQVAE checkpoint released with VAR is 0.92. However, when I followed VAR's tutorial, I got an rFID of 2.70, and I am not sure why.
First, we load the ImageNet validation set and resize/center-crop each image with center_crop_arr (from LlamaGen). We then reconstruct each image with the VQVAE and save the reconstruction in PNG format via Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png"). Finally, we pack the saved reconstruction images into a single .npz file with create_npz_from_sample_folder.
import argparse
import itertools
import os
import os.path as osp
import random

import numpy as np
import PIL.Image as PImage
import PIL.ImageDraw as PImageDraw
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm

from augmentation import center_crop_arr
from metric import PSNR, LPIPS, SSIM
from models import VQVAE, build_vae_var

# Permit TensorFloat-32 for CUDA matmuls and cuDNN convolutions. This trades a
# little numerical precision for speed, so reconstructions may differ very
# slightly from a full-FP32 run.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def create_npz_from_sample_folder(sample_dir, num=50000):
    """Builds a single .npz file from a folder of .png samples.

    Reads ``{sample_dir}/000000.png`` ... ``{sample_dir}/{num-1:06d}.png`` and
    stores them under the key ``arr_0`` as one uint8 array of shape
    (num, H, W, 3), next to the folder as ``{sample_dir}.npz``.

    Args:
        sample_dir: folder containing the numbered PNG files.
        num: number of images to pack (indices ``0 .. num-1``).

    Returns:
        Path of the written ``.npz`` file.

    Raises:
        ValueError: if ``num`` is not positive, or an image's size differs
            from the first image's size.
    """
    if num <= 0:
        raise ValueError(f"num must be positive, got {num}")
    samples = None
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        # `with` closes each file promptly instead of relying on garbage
        # collection; convert("RGB") guarantees exactly 3 channels.
        with Image.open(f"{sample_dir}/{i:06d}.png") as sample_pil:
            sample_np = np.asarray(sample_pil.convert("RGB"), dtype=np.uint8)
        if samples is None:
            # Preallocate once: a Python list followed by np.stack would hold
            # two full copies of the dataset in memory at the same time.
            samples = np.empty((num, *sample_np.shape), dtype=np.uint8)
        elif sample_np.shape != samples.shape[1:]:
            raise ValueError(
                f"{sample_dir}/{i:06d}.png has shape {sample_np.shape}, "
                f"expected {samples.shape[1:]}"
            )
        samples[i] = sample_np
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
def load_dataset(data_path, batch_size=16, image_size=256, num_workers=6):
    """Build a DataLoader over ``<data_path>/val`` with a deterministic order.

    Each image is cropped with ``center_crop_arr`` to ``image_size``, converted
    to a float tensor, and normalized from [0, 1] to [-1, 1].

    Args:
        data_path: dataset root; images are read from its ``val`` subfolder
            (torchvision ``ImageFolder`` layout, one subfolder per class).
        batch_size: images per batch.
        image_size: crop size passed to ``center_crop_arr``; should match the
            resolution the tokenizer is evaluated at (default 256).
        num_workers: number of DataLoader worker processes.

    Returns:
        ``(dataloader, number_of_images_in_val)``.
    """
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
        transforms.ToTensor(),  # -> float CHW in [0, 1]
        # (v - 0.5) / 0.5 maps [0, 1] to [-1, 1], the range the VQVAE consumes.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    val_dataset = ImageFolder(root=os.path.join(data_path, 'val'), transform=transform)
    dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        # Fixed order and no dropped tail, so the i-th saved sample is the
        # i-th dataset item.
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )
    return dataloader, len(val_dataset)
def main():
    """Reconstruct the ImageNet val set with VAR's VQVAE and prepare rFID inputs.

    For every val image this writes the reconstruction to ``sample_folder_dir``
    and the preprocessed original to ``sample_folder_dir + "_ref"``, both as
    ``{index:06d}.png``. It then prints mean PSNR / SSIM / LPIPS over the val
    set and packs each folder into ``<folder>.npz`` for the OpenAI
    ``evaluator.py``.

    NOTE(review): rFID should compare the reconstructions with the images they
    were reconstructed from. Scoring them against
    ``VIRTUAL_imagenet256_labeled.npz`` (a reference batch for generative FID,
    not these val images) measures a different quantity. This is a likely
    reason the score differs from reported tokenizer rFIDs. Run instead::

        python evaluator.py <sample_folder_dir>_ref.npz <sample_folder_dir>.npz
    """
    sample_folder_dir = "/projects/yuanai/processed_data/rFID/baselines/VAR"
    # Ground-truth counterpart of each reconstruction. It goes through the
    # same preprocessing and uint8 conversion, so the two .npz files are paired.
    ref_folder_dir = f"{sample_folder_dir}_ref"
    os.makedirs(sample_folder_dir, exist_ok=True)
    os.makedirs(ref_folder_dir, exist_ok=True)

    # Multi-scale schedule, used both to build the model and to tokenize so
    # the two cannot drift apart.
    patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

    ### load dataset
    data_path = "/projects/yuanai/data/ImageNet/"
    val_dataloader, len_val_set = load_dataset(data_path, batch_size=16)

    ### load the vae checkpoint
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,  # hard-coded VQVAE hyperparameters
        device='cuda', patch_nums=patch_nums,
        num_classes=1000, depth=16, shared_aln=False,
    )
    del var  # only the VQVAE is evaluated; release the VAR model's memory
    vae_ckpt = "/projects/yuanai/processed_data/checkpoint/VAR/vae_ch160v4096z32.pth"
    vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
    vae = vae.cuda()
    vae.eval()

    psnr_metric = PSNR()
    ssim_metric = SSIM()
    lpips_metric = LPIPS()
    ssim, psnr, lpips = 0.0, 0.0, 0.0
    total = 0  # number of images written so far == index of the next file
    with torch.no_grad():
        for x, _ in tqdm(val_dataloader, desc="Reconstructing"):
            x = x.cuda()
            x_rec = vae.img_to_reconstructed_img(x, v_patch_nums=patch_nums, last_one=True)

            # [-1, 1] -> uint8 HWC. The clamp makes values non-negative, so the
            # uint8 cast truncates like floor; adding 128.0 instead of 127.5
            # turns that floor into round-to-nearest. Applied to `x`, this
            # returns the preprocessed input pixels.
            recs = torch.clamp(127.5 * x_rec + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
            refs = torch.clamp(127.5 * x + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
            for offset, (rec, ref) in enumerate(zip(recs, refs)):
                index = total + offset
                Image.fromarray(rec).save(f"{sample_folder_dir}/{index:06d}.png")
                Image.fromarray(ref).save(f"{ref_folder_dir}/{index:06d}.png")
            # Advance by the real batch length: with drop_last=False the final
            # batch is short whenever the dataset size is not a multiple of 16.
            total += recs.shape[0]

            # PSNR / SSIM are computed on [0, 1]; LPIPS receives the [-1, 1] tensors.
            x_norm = (x + 1.0) / 2.0
            x_rec_norm = (x_rec + 1.0) / 2.0
            psnr += psnr_metric(x_norm, x_rec_norm).sum().item()
            ssim += ssim_metric(x_norm, x_rec_norm).sum().item()
            lpips += lpips_metric(x, x_rec).sum().item()

    eval_psnr = psnr / len_val_set
    eval_ssim = ssim / len_val_set
    eval_lpips = lpips / len_val_set
    print("PSNR:" + str(eval_psnr) + " SSIM:" + str(eval_ssim) + " LPIPS:" + str(eval_lpips))

    # Pack exactly the images written above (50000 for the ImageNet val set).
    # Outputs: f"{ref_folder_dir}.npz" and f"{sample_folder_dir}.npz".
    create_npz_from_sample_folder(ref_folder_dir, total)
    create_npz_from_sample_folder(sample_folder_dir, total)
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Using this .npz file, we compute rFID with the OpenAI evaluation toolkit ("python evaluator.py VIRTUAL_imagenet256_labeled.npz our_sampled_imagenet256.npz"). The results are: Inception Score: 56.86065673828125, FID: 2.70789722116308, sFID: 4.6903826389984715, Precision: 0.74194, Recall: 0.6662.
Can anyone explain why my rFID result differs so much from the reported value?