
UNet2DModel - ValueError: cross_attention_dim must be specified for CrossAttnDownBlock2D

Open · phuongtrau opened this issue

Describe the bug

When I load the model for training, I only want to use the unconditional model, but I get the error below. Could anyone tell me how to solve this?

ValueError: cross_attention_dim must be specified for CrossAttnDownBlock2D is raised on this line:

    self.unet = UNet2DModel.from_pretrained(model_key, subfolder="unet").to(self.device)
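A likely cause: the "unet" subfolder of the Stable Diffusion 1.5/2.x checkpoints stores a text-conditioned UNet2DConditionModel whose down blocks are CrossAttnDownBlock2D. As far as I can tell, UNet2DModel has no cross_attention_dim parameter, so when it tries to build those blocks from the saved config it raises exactly this ValueError. diffusers records the class a config was saved from in the "_class_name" key of config.json, so you can check which class a checkpoint expects before loading it. The sketch below is a hypothetical helper (unet_class_for_config and unet_class_for_dir are not diffusers APIs); the fallback that looks for "CrossAttn" block names is my own heuristic for configs that lack "_class_name".

```python
import json
from pathlib import Path


def unet_class_for_config(config):
    """Return the name of the diffusers UNet class a unet/config.json dict expects.

    Hypothetical helper: prefers the "_class_name" key that diffusers writes;
    otherwise guesses from the block types (heuristic, not a diffusers rule).
    """
    if "_class_name" in config:
        return config["_class_name"]

    block_types = [*config.get("down_block_types", []), *config.get("up_block_types", [])]
    mid_block = config.get("mid_block_type")
    if mid_block:
        block_types.append(mid_block)

    # Cross-attention blocks need encoder_hidden_states, i.e. a conditional UNet.
    if any("CrossAttn" in b for b in block_types) or config.get("cross_attention_dim"):
        return "UNet2DConditionModel"
    return "UNet2DModel"


def unet_class_for_dir(unet_dir):
    """Read <unet_dir>/config.json, e.g. a locally downloaded "unet" subfolder."""
    with open(Path(unet_dir) / "config.json") as f:
        return unet_class_for_config(json.load(f))


# Excerpt of the SD 1.5 unet config (no "_class_name" key, so the heuristic decides).
sd15_excerpt = {
    "down_block_types": ["CrossAttnDownBlock2D"] * 3 + ["DownBlock2D"],
    "cross_attention_dim": 768,
}
print(unet_class_for_config(sd15_excerpt))  # UNet2DConditionModel
```

If this reports UNet2DConditionModel, loading the folder with UNet2DModel.from_pretrained is expected to fail.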

Reproduction

from diffusers import AutoencoderKL, UNet2DModel, DDIMScheduler

import torch
import torch.nn as nn
import torch.nn.functional as F

class StableDiffusion(nn.Module):

    def __init__(self, device, sd_version='1.5', hf_key=None):
        super().__init__()

        self.device = device
        self.sd_version = sd_version
        # self.opt = opt
        t_range = [0.02, 0.98]
        print('[INFO] loading stable diffusion...')

        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        # <- raises ValueError: cross_attention_dim must be specified for CrossAttnDownBlock2D
        self.unet = UNet2DModel.from_pretrained(model_key, subfolder="unet").to(self.device)

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        print('[INFO] loaded stable diffusion!')

    def defuse(self, pred_rgb):
        # Interpolate to 512x512 before feeding into the VAE.
        # assert torch.isnan(pred_rgb).sum() == 0, print(pred_rgb)
        input_size = pred_rgb.shape[-2:]  # (H, W); `pred_rgb.size` is a method, not a size

        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)

        # Encode image into latents with the VAE, requires grad!
        # (encode_imgs / decode_latents are defined elsewhere; not shown in this issue)
        latents = self.encode_imgs(pred_rgb_512)

        t = torch.randint(self.min_step, self.max_step + 1, [1], dtype=torch.long, device=self.device)

        # Predict the noise residual with the UNet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2)
            noise_pred = self.unet(latent_model_input, t).sample

        self.w = (1 - self.alphas[t])
        out = self.decode_latents(noise_pred)
        out = F.interpolate(out, input_size, mode='bilinear', align_corners=False)

        return out

Logs

None

System Info


  • diffusers version: 0.28.0.dev0
  • Platform: Linux-5.4.0-113-generic-x86_64-with-glibc2.17
  • Python version: 3.8.19
  • PyTorch version (GPU?): 1.13.1 (True)
  • Huggingface_hub version: 0.22.2
  • Transformers version: 4.39.3
  • Accelerate version: 0.29.2
  • xFormers version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

phuongtrau · Apr 16 '24 05:04

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · May 16 '24 15:05

This should be turned into a discussion, as this is NOT a library issue.

sayakpaul · Jun 29 '24 13:06