The FID of the three timestep_spacing modes ("leading", "linspace", "trailing") with the DDIM scheduler varies greatly with the total number of sampling steps (10 vs. 100)
Describe the bug
None of the three timestep_spacing modes with the DDIM scheduler can reproduce the FID reported in the DDIM paper, and the FID obtained with each mode varies greatly with the number of sampling steps.
The following shows the actual timesteps produced by each timestep_spacing mode (excerpt from `DDIMScheduler.set_timesteps`):
```python
if self.config.timestep_spacing == "linspace":
    # [999, 888, 777, 666, 555, 444, 333, 222, 111, 0]
    timesteps = (
        np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
        .round()[::-1]
        .copy()
        .astype(np.int64)
    )
elif self.config.timestep_spacing == "leading":
    # [900, 800, 700, 600, 500, 400, 300, 200, 100, 0]
    step_ratio = self.config.num_train_timesteps // self.num_inference_steps
    # creates integer timesteps by multiplying by ratio
    # casting to int to avoid issues when num_inference_step is power of 3
    timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
    timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
    # [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]
    step_ratio = self.config.num_train_timesteps / self.num_inference_steps
    # creates integer timesteps by multiplying by ratio
    # casting to int to avoid issues when num_inference_step is power of 3
    timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
    timesteps -= 1
```
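To make the difference concrete, here is a small self-contained sketch (pure NumPy, reproducing the same arithmetic as the scheduler branches above, assuming `num_train_timesteps=1000` and `steps_offset=0`). Note that at 10 steps, "leading" never visits the final training timestep 999 (it starts at 900), while "linspace" and "trailing" both start from 999; this gap shrinks as the step count grows, which is one plausible source of the step-count-dependent FID.

```python
import numpy as np

def ddim_timesteps(spacing, num_inference_steps, num_train_timesteps=1000, steps_offset=0):
    """Reproduce the three timestep_spacing branches of DDIMScheduler.set_timesteps."""
    if spacing == "linspace":
        return np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64)
    if spacing == "leading":
        step_ratio = num_train_timesteps // num_inference_steps
        return (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + steps_offset
    if spacing == "trailing":
        step_ratio = num_train_timesteps / num_inference_steps
        return np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1
    raise ValueError(f"unknown spacing: {spacing}")

for n in (10, 100):
    for spacing in ("linspace", "leading", "trailing"):
        ts = ddim_timesteps(spacing, n)
        print(f"{spacing:>9} @ {n:>3} steps: first={ts[0]}, last={ts[-1]}")
# "leading" starts at 900 for 10 steps but at 990 for 100 steps,
# while "linspace"/"trailing" always start at 999.
```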
Reproduction
```python
import argparse
import os

import accelerate
import torch
import torch_pruning as tp
from tqdm import tqdm

from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel

parser = argparse.ArgumentParser()
parser.add_argument("--total_samples", type=int, default=50000)
parser.add_argument("--batch_size", type=int, default=400)
parser.add_argument("--output_dir", type=str, default="ddpm_cifar10/ddpm-cifar10_ema_model_result/ddim_real10_50k")
parser.add_argument("--model_path", type=str, default="ddpm_cifar10/ckpt/ddpm_ema_cifar10")
parser.add_argument("--ddim_steps", type=int, default=10)
parser.add_argument("--pruned_model_ckpt", type=str, default=None)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--skip_type", type=str, default="uniform")
args = parser.parse_args()

if __name__ == "__main__":
    os.makedirs(args.output_dir, exist_ok=True)
    # pruned model
    accelerator = accelerate.Accelerator()
    model = UNet2DModel(
        downsample_padding=0,
        flip_sin_to_cos=False,
        freq_shift=1,
        in_channels=3,
        layers_per_block=2,
        mid_block_scale_factor=1,
        norm_eps=1e-06,
        norm_num_groups=32,
        out_channels=3,
        resnet_time_scale_shift="default",
        sample_size=32,
        time_embedding_type="positional",
        block_out_channels=(128, 256, 256, 256),
        act_fn="silu",
        add_attention=True,
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
        ),
    ).from_pretrained(args.model_path, subfolder="unet")
    pipeline = DDIMPipeline(
        unet=model,
        scheduler=DDIMScheduler.from_pretrained(args.model_path, subfolder="scheduler"),
    )
    pipeline.scheduler.skip_type = args.skip_type
    # pipeline.scheduler.timestep_spacing = "linspace"
    # pipeline.scheduler.set_timesteps(50)
    # print(pipeline.scheduler.timesteps)
    # exit(0)

    # Test FLOPs
    pipeline.to(accelerator.device)
    if accelerator.is_main_process:
        if "cifar" in args.model_path:
            example_inputs = {
                "sample": torch.randn(1, 3, 32, 32).to(accelerator.device),
                "timestep": torch.ones((1,)).long().to(accelerator.device),
            }
        else:
            example_inputs = {
                "sample": torch.randn(1, 3, 256, 256).to(accelerator.device),
                "timestep": torch.ones((1,)).long().to(accelerator.device),
            }
        macs, params = tp.utils.count_ops_and_params(pipeline.unet, example_inputs)
        print(f"MACs: {macs / 1e9} G, Params: {params / 1e6} M")

    # Create subfolders for each process
    save_sub_dir = os.path.join(args.output_dir, "process_{}".format(accelerator.process_index))
    os.makedirs(save_sub_dir, exist_ok=True)
    generator = torch.Generator(device=pipeline.device).manual_seed(args.seed + accelerator.process_index)

    # Set up progress bar
    if not accelerator.is_main_process:
        pipeline.set_progress_bar_config(disable=True)

    # Sampling
    accelerator.wait_for_everyone()
    with torch.no_grad():
        # num_batches of each process
        num_batches = args.total_samples // (args.batch_size * accelerator.num_processes)
        if accelerator.is_main_process:
            print(
                "Sampling {}x{}={} images with {} process(es)".format(
                    num_batches * args.batch_size,
                    accelerator.num_processes,
                    num_batches * accelerator.num_processes * args.batch_size,
                    accelerator.num_processes,
                )
            )
        for i in tqdm(range(num_batches), disable=not accelerator.is_main_process):
            images = pipeline(batch_size=args.batch_size, num_inference_steps=args.ddim_steps, generator=generator).images
            for j, image in enumerate(images):
                filename = os.path.join(save_sub_dir, f"{i * args.batch_size + j}.png")
                image.save(filename)

    # Finished
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.print(f"Saved {num_batches * accelerator.num_processes * args.batch_size} samples to {args.output_dir}")
    # accelerator.end_training()
```
Logs
No response
System Info
```
Name: diffusers
Version: 0.21.0
Summary: State-of-the-art diffusion in PyTorch and JAX.
Home-page: https://github.com/huggingface/diffusers
Author: The HuggingFace team
Author-email: [email protected]
License: Apache
```
Who can help?
No response
Eval is a tricky business. We need to ensure that the eval configurations match apples to apples in order to reproduce those numbers. In most cases, this is simply not possible. So, I don't have any concrete suggestions for you.
But I will let @patil-suraj comment further since he has experience running FID evals for aMUSEd.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Closing because of inactivity.