
The FID of the three timestep_spacing modes (leading, linspace, trailing) using the DDIM scheduler varies greatly with the total number of sampling steps (10 to 100 steps)

xinding64 opened this issue 2 years ago · 3 comments

Describe the bug

None of the three timestep_spacing modes using the DDIM scheduler reproduces the FID reported in the DDIM paper, and the FID varies greatly with the number of sampling steps across the three modes. (A screenshot was attached here.)

The following shows the actual timesteps produced by each timestep_spacing mode:

```python
if self.config.timestep_spacing == "linspace":
    # [999, 888, 777, 666, 555, 444, 333, 222, 111, 0]
    timesteps = (
        np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
        .round()[::-1]
        .copy()
        .astype(np.int64)
    )
elif self.config.timestep_spacing == "leading":
    # [900, 800, 700, 600, 500, 400, 300, 200, 100, 0]
    step_ratio = self.config.num_train_timesteps // self.num_inference_steps
    # creates integer timesteps by multiplying by ratio
    # casting to int to avoid issues when num_inference_step is power of 3
    timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
    timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
    # [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]
    step_ratio = self.config.num_train_timesteps / self.num_inference_steps
    # creates integer timesteps by multiplying by ratio
    # casting to int to avoid issues when num_inference_step is power of 3
    timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
    timesteps -= 1
```
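The three branches can be checked in isolation. Below is a minimal standalone sketch (not the library code itself) that reproduces each spacing rule; the expected values assume `num_train_timesteps=1000`, `num_inference_steps=10`, and `steps_offset=0`:

```python
import numpy as np

def spaced_timesteps(spacing, num_train_timesteps=1000, num_inference_steps=10, steps_offset=0):
    """Standalone reproduction of the three DDIMScheduler spacing rules."""
    if spacing == "linspace":
        # evenly spaced over [0, T-1], including both endpoints
        return np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].astype(np.int64)
    if spacing == "leading":
        # multiples of T // n, starting from 0 (never reaches T-1)
        step_ratio = num_train_timesteps // num_inference_steps
        return (np.arange(num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) + steps_offset
    if spacing == "trailing":
        # counts down from T in steps of T / n, then shifts by -1 (never reaches 0)
        step_ratio = num_train_timesteps / num_inference_steps
        return np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1
    raise ValueError(f"unknown spacing: {spacing}")

for mode in ("linspace", "leading", "trailing"):
    print(mode, spaced_timesteps(mode))
```

Note that only "linspace" includes both t=999 and t=0; "leading" stops at t=900 and "trailing" stops at t=99, which is one source of the FID differences between the modes.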

Reproduction

```python
import argparse
import os

import accelerate
import torch
import torch_pruning as tp
from tqdm import tqdm

from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel

parser = argparse.ArgumentParser()
parser.add_argument("--total_samples", type=int, default=50000)
parser.add_argument("--batch_size", type=int, default=400)
parser.add_argument("--output_dir", type=str, default="ddpm_cifar10/ddpm-cifar10_ema_model_result/ddim_real10_50k")
parser.add_argument("--model_path", type=str, default="ddpm_cifar10/ckpt/ddpm_ema_cifar10")
parser.add_argument("--ddim_steps", type=int, default=10)
parser.add_argument("--pruned_model_ckpt", type=str, default=None)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--skip_type", type=str, default="uniform")
args = parser.parse_args()

if __name__ == "__main__":
    os.makedirs(args.output_dir, exist_ok=True)
    # pruned model
    accelerator = accelerate.Accelerator()
    model = UNet2DModel(
        downsample_padding=0,
        flip_sin_to_cos=False,
        freq_shift=1,
        in_channels=3,
        layers_per_block=2,
        mid_block_scale_factor=1,
        norm_eps=1e-06,
        norm_num_groups=32,
        out_channels=3,
        resnet_time_scale_shift="default",
        sample_size=32,
        time_embedding_type="positional",
        block_out_channels=(128, 256, 256, 256),
        act_fn="silu",
        add_attention=True,
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
        ),
    ).from_pretrained(args.model_path, subfolder="unet")

    pipeline = DDIMPipeline(
        unet=model,
        scheduler=DDIMScheduler.from_pretrained(args.model_path, subfolder="scheduler"))

    pipeline.scheduler.skip_type = args.skip_type
    # pipeline.scheduler.timestep_spacing = "linspace"
    # pipeline.scheduler.set_timesteps(50)
    # print(pipeline.scheduler.timesteps)
    # exit(0)

    # Test FLOPs
    pipeline.to(accelerator.device)
    if accelerator.is_main_process:
        if 'cifar' in args.model_path:
            example_inputs = {'sample': torch.randn(1, 3, 32, 32).to(accelerator.device), 'timestep': torch.ones((1,)).long().to(accelerator.device)}
        else:
            example_inputs = {'sample': torch.randn(1, 3, 256, 256).to(accelerator.device), 'timestep': torch.ones((1,)).long().to(accelerator.device)}
        macs, params = tp.utils.count_ops_and_params(pipeline.unet, example_inputs)
        print(f"MACS: {macs/1e9} G, Params: {params/1e6} M")

    # Create subfolders for each process
    save_sub_dir = os.path.join(args.output_dir, 'process_{}'.format(accelerator.process_index))
    os.makedirs(save_sub_dir, exist_ok=True)
    generator = torch.Generator(device=pipeline.device).manual_seed(args.seed + accelerator.process_index)

    # Set up progress bar
    if not accelerator.is_main_process:
        pipeline.set_progress_bar_config(disable=True)

    # Sampling
    accelerator.wait_for_everyone()
    with torch.no_grad():
        # num_batches of each process
        num_batches = args.total_samples // (args.batch_size * accelerator.num_processes)
        if accelerator.is_main_process:
            print("Sampling {}x{}={} images with {} process(es)".format(num_batches * args.batch_size, accelerator.num_processes, num_batches * accelerator.num_processes * args.batch_size, accelerator.num_processes))
        for i in tqdm(range(num_batches), disable=not accelerator.is_main_process):
            images = pipeline(batch_size=args.batch_size, num_inference_steps=args.ddim_steps, generator=generator).images
            for j, image in enumerate(images):
                filename = os.path.join(save_sub_dir, f"{i * args.batch_size + j}.png")
                image.save(filename)

    # Finished
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.print(f"Saved {num_batches * accelerator.num_processes * args.batch_size} samples to {args.output_dir}")
    # accelerator.end_training()
```
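The script above only saves samples; the FID itself is the Fréchet distance between Gaussian fits of Inception features of real and generated images. As a reference point, here is a minimal sketch of that final distance computation (assuming feature vectors have already been extracted by some Inception backbone; this mirrors the standard formula, not any particular library's implementation):

```python
import numpy as np
from scipy import linalg

def frechet_distance(feats1, feats2):
    """Frechet distance between Gaussian fits of two feature sets (rows = samples).

    FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 * sqrt(C1 @ C2))
    """
    mu1, mu2 = feats1.mean(axis=0), feats2.mean(axis=0)
    sigma1 = np.cov(feats1, rowvar=False)
    sigma2 = np.cov(feats2, rowvar=False)
    diff = mu1 - mu2
    # matrix square root of the covariance product; discard tiny imaginary parts
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```

Because the score depends on the feature extractor, the number of samples, and the reference set, all of these must match the DDIM paper's setup for the numbers to be comparable.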

Logs

No response

System Info

Name: diffusers
Version: 0.21.0
Summary: State-of-the-art diffusion in PyTorch and JAX.
Home-page: https://github.com/huggingface/diffusers
Author: The HuggingFace team
Author-email: [email protected]
License: Apache

Who can help?

No response

xinding64 · Mar 20 '24

Eval is a tricky business. We need to ensure that the eval configurations match apples to apples in order to reproduce those numbers. In most cases, this is simply not possible. So, I don't have any concrete suggestions for you.

But I will let @patil-suraj comment further since he has experience running FID evals for aMUSEd.

sayakpaul · Mar 20 '24

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Apr 19 '24


Closing because of inactivity.

sayakpaul · Jun 29 '24