ComfyUI-layerdiffuse icon indicating copy to clipboard operation
ComfyUI-layerdiffuse copied to clipboard

[Bug]: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

Open designex opened this issue 1 year ago • 4 comments

What happened?

微信截图_20240826183956

Steps to reproduce the problem

After using the generated image once, an error message will pop up

What should have happened?

After using the generated image once, an error message will pop up

Commit where the problem happens

ComfyUI: ComfyUI-layerdiffuse:

Sysinfo

none

Console logs

Error occurred when executing KSampler (Efficient):

Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

File "D:\ComyUI2024\ComfyUI\execution.py", line 152, in recursive_execute
output_data, output_ui = get_output_data(obj, input_data_all)
File "D:\ComyUI2024\ComfyUI\execution.py", line 82, in get_output_data
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
File "D:\ComyUI2024\ComfyUI\execution.py", line 75, in map_node_over_list
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
File "D:\ComyUI2024\ComfyUI\custom_nodes\efficiency-nodes-comfyui\efficiency_nodes.py", line 732, in sample
samples, images, gifs, preview = process_latent_image(model, seed, steps, cfg, sampler_name, scheduler,
File "D:\ComyUI2024\ComfyUI\custom_nodes\efficiency-nodes-comfyui\efficiency_nodes.py", line 550, in process_latent_image
samples = KSampler().sample(model, seed, steps, cfg, sampler_name, scheduler, positive, negative,
File "D:\ComyUI2024\ComfyUI\nodes.py", line 1382, in sample
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
File "D:\ComyUI2024\ComfyUI\nodes.py", line 1352, in common_ksampler
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
File "D:\ComyUI2024\ComfyUI\custom_nodes\ComfyUI-Impact-Pack\modules\impact\sample_error_enhancer.py", line 22, in informative_sample
raise e
File "D:\ComyUI2024\ComfyUI\custom_nodes\ComfyUI-Impact-Pack\modules\impact\sample_error_enhancer.py", line 9, in informative_sample
return original_sample(*args, **kwargs) # This code helps interpret error messages that occur within exceptions but does not have any impact on other operations.
File "D:\ComyUI2024\ComfyUI\custom_nodes\ComfyUI-AnimateDiff-Evolved\animatediff\sampling.py", line 434, in motion_sample
return orig_comfy_sample(model, noise, *args, **kwargs)
File "D:\ComyUI2024\ComfyUI\custom_nodes\ComfyUI-Advanced-ControlNet\adv_control\sampling.py", line 116, in acn_sample
return orig_comfy_sample(model, *args, **kwargs)
File "D:\ComyUI2024\ComfyUI\custom_nodes\ComfyUI-Advanced-ControlNet\adv_control\utils.py", line 116, in uncond_multiplier_check_cn_sample
return orig_comfy_sample(model, *args, **kwargs)
File "D:\ComyUI2024\ComfyUI\comfy\sample.py", line 43, in sample
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
File "D:\ComyUI2024\ComfyUI\comfy\samplers.py", line 829, in sample
return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
File "D:\ComyUI2024\ComfyUI\comfy\samplers.py", line 729, in sample
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
File "D:\ComyUI2024\ComfyUI\comfy\samplers.py", line 716, in sample
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
File "D:\ComyUI2024\ComfyUI\comfy\samplers.py", line 695, in inner_sample
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
File "D:\ComyUI2024\ComfyUI\comfy\samplers.py", line 600, in sample
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\comfy\k_diffusion\sampling.py", line 143, in sample_euler
denoised = model(x, sigma_hat * s_in, **extra_args)
File "D:\ComyUI2024\ComfyUI\comfy\samplers.py", line 299, in __call__
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
File "D:\ComyUI2024\ComfyUI\comfy\samplers.py", line 682, in __call__
return self.predict_noise(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\comfy\samplers.py", line 685, in predict_noise
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
File "D:\ComyUI2024\ComfyUI\comfy\samplers.py", line 279, in sampling_function
out = calc_cond_batch(model, conds, x, timestep, model_options)
File "D:\ComyUI2024\ComfyUI\comfy\samplers.py", line 228, in calc_cond_batch
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
File "D:\ComyUI2024\ComfyUI\custom_nodes\ComfyUI-Advanced-ControlNet\adv_control\utils.py", line 68, in apply_model_uncond_cleanup_wrapper
return orig_apply_model(self, *args, **kwargs)
File "D:\ComyUI2024\ComfyUI\comfy\model_base.py", line 145, in apply_model
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\custom_nodes\FreeU_Advanced\nodes.py", line 176, in __temp__forward
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
File "D:\ComyUI2024\ComfyUI\comfy\ldm\modules\diffusionmodules\openaimodel.py", line 44, in forward_timestep_embed
x = layer(x, context, transformer_options)
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\comfy\ldm\modules\attention.py", line 694, in forward
x = block(x, context=context[i], transformer_options=transformer_options)
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\custom_nodes\ComfyUI-layerdiffuse\lib_layerdiffusion\attention_sharing.py", line 253, in forward
return func(self, x, context, transformer_options)
File "D:\ComyUI2024\ComfyUI\comfy\ldm\modules\attention.py", line 581, in forward
n = self.attn1(n, context=context_attn1, value=value_attn1)
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\comfy\ldm\modules\attention.py", line 465, in forward
q = self.to_q(x)
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\comfy\ops.py", line 65, in forward
return super().forward(*args, **kwargs)
File "D:\ComyUI2024\ComfyUI\python\lib\site-packages\torch\nn\modules\linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)

Workflow json file

workflow (1).json

Additional information

No response

designex avatar Aug 26 '24 10:08 designex

same here

HajunKim avatar Sep 11 '24 08:09 HajunKim

我也遇到这样的问题

2467646299 avatar Feb 17 '25 12:02 2467646299

same here

dajuanmao0224 avatar Mar 05 '25 11:03 dajuanmao0224

In `comfyui-layerdiffuse\lib_layerdiffusion\attention_sharing.py`, add the following code to fix the issue. I am not going to create a pull request. @huchenlei


# Currently only sd15

import functools
import torch
import einops

from comfy import model_management
from comfy.ldm.modules.attention import optimized_attention
from comfy.model_patcher import ModelPatcher


# Index -> attention-module path inside the SD1.5 UNet. The 32 entries cover
# the attn1 (self-attention) and attn2 (cross-attention) modules of every
# transformer block that the attention-sharing patch replaces; keys match the
# layer ordering expected by the pretrained layer-diffusion state dict.
module_mapping_sd15 = {
    0: "input_blocks.1.1.transformer_blocks.0.attn1",
    1: "input_blocks.1.1.transformer_blocks.0.attn2",
    2: "input_blocks.2.1.transformer_blocks.0.attn1",
    3: "input_blocks.2.1.transformer_blocks.0.attn2",
    4: "input_blocks.4.1.transformer_blocks.0.attn1",
    5: "input_blocks.4.1.transformer_blocks.0.attn2",
    6: "input_blocks.5.1.transformer_blocks.0.attn1",
    7: "input_blocks.5.1.transformer_blocks.0.attn2",
    8: "input_blocks.7.1.transformer_blocks.0.attn1",
    9: "input_blocks.7.1.transformer_blocks.0.attn2",
    10: "input_blocks.8.1.transformer_blocks.0.attn1",
    11: "input_blocks.8.1.transformer_blocks.0.attn2",
    12: "output_blocks.3.1.transformer_blocks.0.attn1",
    13: "output_blocks.3.1.transformer_blocks.0.attn2",
    14: "output_blocks.4.1.transformer_blocks.0.attn1",
    15: "output_blocks.4.1.transformer_blocks.0.attn2",
    16: "output_blocks.5.1.transformer_blocks.0.attn1",
    17: "output_blocks.5.1.transformer_blocks.0.attn2",
    18: "output_blocks.6.1.transformer_blocks.0.attn1",
    19: "output_blocks.6.1.transformer_blocks.0.attn2",
    20: "output_blocks.7.1.transformer_blocks.0.attn1",
    21: "output_blocks.7.1.transformer_blocks.0.attn2",
    22: "output_blocks.8.1.transformer_blocks.0.attn1",
    23: "output_blocks.8.1.transformer_blocks.0.attn2",
    24: "output_blocks.9.1.transformer_blocks.0.attn1",
    25: "output_blocks.9.1.transformer_blocks.0.attn2",
    26: "output_blocks.10.1.transformer_blocks.0.attn1",
    27: "output_blocks.10.1.transformer_blocks.0.attn2",
    28: "output_blocks.11.1.transformer_blocks.0.attn1",
    29: "output_blocks.11.1.transformer_blocks.0.attn2",
    30: "middle_block.1.transformer_blocks.0.attn1",
    31: "middle_block.1.transformer_blocks.0.attn2",
}


def compute_cond_mark(cond_or_uncond, sigmas):
    """Expand per-chunk cond/uncond flags into one mark per latent sample.

    Each flag in ``cond_or_uncond`` is repeated ``sigmas.shape[0]`` times so
    the result lines up with the flattened batch; the returned tensor is
    placed on the same device/dtype as ``sigmas``.
    """
    repeats = int(sigmas.shape[0])
    marks = [flag for flag in cond_or_uncond for _ in range(repeats)]
    return torch.Tensor(marks).to(sigmas)


class LoRALinearLayer(torch.nn.Module):
    """Low-rank adapter evaluated as a single fused linear layer.

    Computes ``linear(h, org.weight + up @ down, org.bias)``. The wrapped
    original layer is stored inside a plain list so it is NOT registered as a
    submodule (its parameters stay owned by the base model).
    """

    def __init__(self, in_features: int, out_features: int, rank: int = 256, org=None):
        super().__init__()
        # Keep (down, up) creation order stable for reproducible initialization.
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)
        self.org = [org]

    def forward(self, h):
        base = self.org[0]
        # Cast/move every weight onto the activation's device and dtype so the
        # matmul never mixes cuda and cpu tensors (the reported crash).
        weight = base.weight.to(h)
        bias = base.bias.to(h) if base.bias is not None else None
        delta = torch.mm(self.up.weight.to(h), self.down.weight.to(h))
        return torch.nn.functional.linear(h, weight + delta, bias)


class AttentionSharingUnit(torch.nn.Module):
    """Drop-in replacement for one UNet attention module that shares attention
    state across ``frames`` interleaved images.

    Wraps the original attention module, adds per-frame LoRA projections plus
    a temporal attention pass, and returns a residual relative to its input
    (see the final ``- h`` in :meth:`forward`).
    """

    # `transformer_options` passed to the most recent BasicTransformerBlock.forward
    # call.
    # NOTE: this is class-level (shared) state, written by the hook installed in
    # hijack_transformer_block() below; all units read the latest value.
    transformer_options: dict = {}

    def __init__(self, module, frames=2, use_control=True, rank=256):
        """Build the sharing unit around an existing attention `module`.

        Args:
            module: the original attention module being replaced; its
                to_q/to_k/to_v/to_out layers provide the channel sizes and are
                wrapped (not copied) by the LoRA layers.
            frames: number of interleaved images sharing attention.
            use_control: whether to build the control-signal conv branches.
            rank: LoRA rank for all projection adapters.
        """
        super().__init__()

        self.heads = module.heads
        self.frames = frames
        # Plain list so the original module is not registered as a submodule.
        self.original_module = [module]
        q_in_channels, q_out_channels = (
            module.to_q.in_features,
            module.to_q.out_features,
        )
        k_in_channels, k_out_channels = (
            module.to_k.in_features,
            module.to_k.out_features,
        )
        v_in_channels, v_out_channels = (
            module.to_v.in_features,
            module.to_v.out_features,
        )
        o_in_channels, o_out_channels = (
            module.to_out[0].in_features,
            module.to_out[0].out_features,
        )

        hidden_size = k_out_channels

        # One LoRA adapter per frame for each of q/k/v/out.
        self.to_q_lora = [
            LoRALinearLayer(q_in_channels, q_out_channels, rank, module.to_q)
            for _ in range(self.frames)
        ]
        self.to_k_lora = [
            LoRALinearLayer(k_in_channels, k_out_channels, rank, module.to_k)
            for _ in range(self.frames)
        ]
        self.to_v_lora = [
            LoRALinearLayer(v_in_channels, v_out_channels, rank, module.to_v)
            for _ in range(self.frames)
        ]
        self.to_out_lora = [
            LoRALinearLayer(o_in_channels, o_out_channels, rank, module.to_out[0])
            for _ in range(self.frames)
        ]

        self.to_q_lora = torch.nn.ModuleList(self.to_q_lora)
        self.to_k_lora = torch.nn.ModuleList(self.to_k_lora)
        self.to_v_lora = torch.nn.ModuleList(self.to_v_lora)
        self.to_out_lora = torch.nn.ModuleList(self.to_out_lora)

        # Temporal (cross-frame) attention: layernorm + in-projection, then
        # q/k/v/out linears applied over the frame axis in forward().
        self.temporal_i = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_n = torch.nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6
        )
        self.temporal_q = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_k = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_v = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_o = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )

        self.control_convs = None

        if use_control:
            # Per-frame conv stacks mapping a 256-channel control feature map
            # to this layer's hidden size.
            self.control_convs = [
                torch.nn.Sequential(
                    torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(256, hidden_size, kernel_size=1),
                )
                for _ in range(self.frames)
            ]
            self.control_convs = torch.nn.ModuleList(self.control_convs)

        # Populated externally by AttentionSharingPatcher.set_control().
        self.control_signals = None

    def forward(self, h, context=None, value=None):
        # `value` is accepted to match the patched attention call signature
        # (attn1(n, context=..., value=...)) but is unused here.
        #
        # Move this unit onto the activations' device. Without this, the LoRA
        # and temporal weights can remain on CPU while `h` is on CUDA,
        # producing "Expected all tensors to be on the same device".
        self.to(h.device)
        transformer_options = self.transformer_options

        # Split the interleaved frames out of the batch dimension.
        modified_hidden_states = einops.rearrange(
            h, "(b f) d c -> f b d c", f=self.frames
        )

        if self.control_convs is not None:
            # Control signals are keyed by token count (H*W of the feature map)
            # so each attention layer picks the matching resolution.
            context_dim = int(modified_hidden_states.shape[2])
            control_outs = []
            for f in range(self.frames):
                control_signal = self.control_signals[context_dim].to(
                    modified_hidden_states
                )
                control = self.control_convs[f](control_signal)
                control = einops.rearrange(control, "b c h w -> b (h w) c")
                control_outs.append(control)
            control_outs = torch.stack(control_outs, dim=0)
            modified_hidden_states = modified_hidden_states + control_outs.to(
                modified_hidden_states
            )

        if context is None:
            # Self-attention: each frame attends over its own hidden states.
            framed_context = modified_hidden_states
        else:
            framed_context = einops.rearrange(
                context, "(b f) d c -> f b d c", f=self.frames
            )

        # 0/1 mark per (frame, batch) entry telling cond from uncond chunks.
        framed_cond_mark = einops.rearrange(
            compute_cond_mark(
                transformer_options["cond_or_uncond"],
                transformer_options["sigmas"],
            ),
            "(b f) -> f b",
            f=self.frames,
        ).to(modified_hidden_states)

        attn_outs = []
        for f in range(self.frames):
            fcf = framed_context[f]

            if context is not None:
                # Optional per-frame context replacement; blended in via the
                # cond mark (presumably overwriting only the uncond chunk —
                # verify against the caller that fills `cond_overwrite`).
                cond_overwrite = transformer_options.get("cond_overwrite", [])
                if len(cond_overwrite) > f:
                    cond_overwrite = cond_overwrite[f]
                else:
                    cond_overwrite = None
                if cond_overwrite is not None:
                    cond_mark = framed_cond_mark[f][:, None, None]
                    fcf = cond_overwrite.to(fcf) * (1.0 - cond_mark) + fcf * cond_mark

            # Per-frame LoRA-adapted attention using the original module's heads.
            q = self.to_q_lora[f](modified_hidden_states[f])
            k = self.to_k_lora[f](fcf)
            v = self.to_v_lora[f](fcf)
            o = optimized_attention(q, k, v, self.heads)
            o = self.to_out_lora[f](o)
            o = self.original_module[0].to_out[1](o)
            attn_outs.append(o)

        attn_outs = torch.stack(attn_outs, dim=0)
        modified_hidden_states = modified_hidden_states + attn_outs.to(
            modified_hidden_states
        )
        modified_hidden_states = einops.rearrange(
            modified_hidden_states, "f b d c -> (b f) d c", f=self.frames
        )

        # Temporal pass: attend across frames at each spatial token.
        x = modified_hidden_states
        x = self.temporal_n(x)
        x = self.temporal_i(x)
        d = x.shape[1]  # token count, needed to undo the rearrange below

        x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames)

        q = self.temporal_q(x)
        k = self.temporal_k(x)
        v = self.temporal_v(x)

        x = optimized_attention(q, k, v, self.heads)
        x = self.temporal_o(x)
        x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d)

        modified_hidden_states = modified_hidden_states + x

        # Return a residual: the caller adds this module's output back onto
        # its input, so subtracting `h` makes the sum equal the shared result.
        return modified_hidden_states - h

    @classmethod
    def hijack_transformer_block(cls):
        """Monkeypatch BasicTransformerBlock.forward so every call records its
        `transformer_options` into the shared class attribute above."""

        def register_get_transformer_options(func):
            @functools.wraps(func)
            def forward(self, x, context=None, transformer_options={}):
                cls.transformer_options = transformer_options
                return func(self, x, context, transformer_options)

            return forward

        from comfy.ldm.modules.attention import BasicTransformerBlock

        BasicTransformerBlock.forward = register_get_transformer_options(
            BasicTransformerBlock.forward
        )


# Install the BasicTransformerBlock.forward hook at import time so transformer
# blocks record their `transformer_options` for the sharing units above.
AttentionSharingUnit.hijack_transformer_block()


class AdditionalAttentionCondsEncoder(torch.nn.Module):
    """CNN encoding an RGB control image into multi-resolution feature maps.

    Calling the encoder returns a dict keyed by the flattened spatial size
    (H*W) of each feature map, so attention layers can look up the map that
    matches their token count.
    """

    @staticmethod
    def _downsample_stage():
        # One stride-2 downsample followed by a stride-1 refinement at 256ch.
        return torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )

    def __init__(self):
        super().__init__()

        # Stem: 3 -> 256 channels while shrinking spatial size by 8x.
        self.blocks_0 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 64*64*256
        self.blocks_1 = self._downsample_stage()  # 32*32*256
        self.blocks_2 = self._downsample_stage()  # 16*16*256
        self.blocks_3 = self._downsample_stage()  # 8*8*256

        self.blks = [self.blocks_0, self.blocks_1, self.blocks_2, self.blocks_3]

    def __call__(self, h):
        features = {}
        for stage in self.blks:
            h = stage(h)
            features[int(h.shape[2]) * int(h.shape[3])] = h
        return features


class HookerLayers(torch.nn.Module):
    """Registers the attention-sharing units as proper submodules.

    Placing the units in a ModuleList lets ``load_state_dict`` and dtype
    casts (e.g. ``.half()``) reach every unit at once.
    """

    def __init__(self, layer_list):
        super().__init__()
        registered = torch.nn.ModuleList(layer_list)
        self.layers = registered


class AttentionSharingPatcher(torch.nn.Module):
    """Patches all 32 SD1.5 attention modules of a UNet with sharing units and
    loads the layer-diffusion weights into them."""

    def __init__(self, unet: ModelPatcher, frames=2, use_control=True, rank=256):
        """Replace each mapped attention module on `unet` with a sharing unit.

        Args:
            unet: ComfyUI model patcher whose diffusion model gets patched.
            frames: number of interleaved images sharing attention.
            use_control: whether to build the control-image encoder branch.
            rank: LoRA rank forwarded to every AttentionSharingUnit.
        """
        super().__init__()

        units = []
        for i in range(32):
            key = "diffusion_model." + module_mapping_sd15[i]
            attn_module = unet.get_model_object(key)
            u = AttentionSharingUnit(
                attn_module, frames=frames, use_control=use_control, rank=rank
            )
            units.append(u)
            # Object patch: the sharing unit replaces the attention module for
            # this patched model only (the base model stays untouched).
            unet.add_object_patch(key, u)

        self.hookers = HookerLayers(units)

        if use_control:
            self.kwargs_encoder = AdditionalAttentionCondsEncoder()
        else:
            self.kwargs_encoder = None

        # Match the device's preferred precision; half() casts every unit.
        self.dtype = torch.float32
        if model_management.should_use_fp16(model_management.get_torch_device()):
            self.dtype = torch.float16
            self.hookers.half()
        return

    def set_control(self, img):
        """Encode a control image and hand the per-resolution feature maps to
        every sharing unit (encoding runs on CPU in float32)."""
        img = img.cpu().float() * 2.0 - 1.0  # [0,1] -> [-1,1]
        signals = self.kwargs_encoder(img)
        for m in self.hookers.layers:
            m.control_signals = signals
        return

Additionally, to unload models once they are no longer in use (optional), edit `comfyui-layerdiffuse\layered_diffusion.py` as follows:


import os
from enum import Enum
import torch
import copy
from typing import Optional, List
from dataclasses import dataclass

import folder_paths
import comfy.model_management
import comfy.model_base
import comfy.supported_models
import comfy.supported_models_base
from comfy.model_patcher import ModelPatcher
from folder_paths import get_folder_paths
from comfy.utils import load_torch_file
from comfy_extras.nodes_compositing import JoinImageWithAlpha
from comfy.conds import CONDRegular
from .lib_layerdiffusion.utils import (
    load_file_from_url,
    to_lora_patch_dict,
)
from .lib_layerdiffusion.models import TransparentVAEDecoder
from .lib_layerdiffusion.attention_sharing import AttentionSharingPatcher
from .lib_layerdiffusion.enums import StableDiffusionVersion

# Resolve the directory that stores layer-diffusion weights: prefer a
# user-configured "layer_model" folder, otherwise fall back to
# <models_dir>/layer_model.
if "layer_model" in folder_paths.folder_names_and_paths:
    layer_model_root = get_folder_paths("layer_model")[0]
else:
    layer_model_root = os.path.join(folder_paths.models_dir, "layer_model")
# Alias kept for readability at the call sites below.
load_layer_model_state_dict = load_torch_file


class LayeredDiffusionDecode:
    """
    Decode alpha channel value from pixel value.
    [B, C=3, H, W] => [B, C=4, H, W]
    Outputs RGB image + Alpha mask.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT",),
                "images": ("IMAGE",),
                "sd_version": (
                    [
                        StableDiffusionVersion.SD1x.value,
                        StableDiffusionVersion.SDXL.value,
                    ],
                    {
                        "default": StableDiffusionVersion.SDXL.value,
                    },
                ),
                # Memory/speed knob: images decoded per VAE pass.
                "sub_batch_size": (
                    "INT",
                    {"default": 16, "min": 1, "max": 4096, "step": 1},
                ),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "decode"
    CATEGORY = "layer_diffuse"

    def __init__(self) -> None:
        # Lazily-built TransparentVAEDecoder instances keyed by SD version,
        # so weights are only downloaded/loaded on first use.
        self.vae_transparent_decoder = {}

    def decode(self, samples, images, sd_version: str, sub_batch_size: int):
        """Decode latents + RGB pixels into an RGB image and an alpha mask.

        sub_batch_size: How many images to decode in a single pass.
        See https://github.com/huchenlei/ComfyUI-layerdiffuse/pull/4 for more
        context.

        Raises:
            NotImplementedError: if `sd_version` has no transparent decoder.
        """
        sd_version = StableDiffusionVersion(sd_version)
        if sd_version == StableDiffusionVersion.SD1x:
            url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_vae_transparent_decoder.safetensors"
            file_name = "layer_sd15_vae_transparent_decoder.safetensors"
        elif sd_version == StableDiffusionVersion.SDXL:
            url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/vae_transparent_decoder.safetensors"
            file_name = "vae_transparent_decoder.safetensors"
        else:
            # Previously an unknown version fell through and crashed later with
            # UnboundLocalError on `url`; fail loudly and clearly instead.
            raise NotImplementedError(f"Unsupported SD version: {sd_version}.")

        if not self.vae_transparent_decoder.get(sd_version):
            model_path = load_file_from_url(
                url=url, model_dir=layer_model_root, file_name=file_name
            )
            self.vae_transparent_decoder[sd_version] = TransparentVAEDecoder(
                load_torch_file(model_path),
                device=comfy.model_management.get_torch_device(),
                dtype=(
                    torch.float16
                    if comfy.model_management.should_use_fp16()
                    else torch.float32
                ),
            )
        pixel = images.movedim(-1, 1)  # [B, H, W, C] => [B, C, H, W]

        # Decoder requires dimension to be 64-aligned.
        B, C, H, W = pixel.shape
        assert H % 64 == 0, f"Height({H}) is not multiple of 64."
        # Bug fix: this message previously said "Height" for the width check.
        assert W % 64 == 0, f"Width({W}) is not multiple of 64."

        # Decode in sub-batches to bound peak memory usage.
        decoded = []
        for start_idx in range(0, samples["samples"].shape[0], sub_batch_size):
            decoded.append(
                self.vae_transparent_decoder[sd_version].decode_pixel(
                    pixel[start_idx : start_idx + sub_batch_size],
                    samples["samples"][start_idx : start_idx + sub_batch_size],
                )
            )
        pixel_with_alpha = torch.cat(decoded, dim=0)

        # [B, C, H, W] => [B, H, W, C]; channel 0 is alpha, channels 1: are RGB.
        pixel_with_alpha = pixel_with_alpha.movedim(1, -1)
        image = pixel_with_alpha[..., 1:]
        alpha = pixel_with_alpha[..., 0]
        return (image, alpha)


class LayeredDiffusionDecodeRGBA(LayeredDiffusionDecode):
    """
    Decode alpha channel value from pixel value.
    [B, C=3, H, W] => [B, C=4, H, W]
    Outputs RGBA image.
    """

    RETURN_TYPES = ("IMAGE",)

    def decode(self, samples, images, sd_version: str, sub_batch_size: int):
        # Reuse the base decode, then join the inverted mask back in as alpha.
        rgb, mask = super().decode(samples, images, sd_version, sub_batch_size)
        return JoinImageWithAlpha().join_image_with_alpha(rgb, 1.0 - mask)


class LayeredDiffusionDecodeSplit(LayeredDiffusionDecodeRGBA):
    """Decode RGBA every N images."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT",),
                "images": ("IMAGE",),
                # RGBA decode happens once every `frames` output images.
                "frames": (
                    "INT",
                    {"default": 2, "min": 2, "max": s.MAX_FRAMES, "step": 1},
                ),
                "sd_version": (
                    [
                        StableDiffusionVersion.SD1x.value,
                        StableDiffusionVersion.SDXL.value,
                    ],
                    {
                        "default": StableDiffusionVersion.SDXL.value,
                    },
                ),
                "sub_batch_size": (
                    "INT",
                    {"default": 16, "min": 1, "max": 4096, "step": 1},
                ),
            },
        }

    MAX_FRAMES = 3
    RETURN_TYPES = ("IMAGE",) * MAX_FRAMES

    def decode(
        self,
        samples,
        images: torch.Tensor,
        frames: int,
        sd_version: str,
        sub_batch_size: int,
    ):
        # Only every `frames`-th latent belongs to the stream that receives
        # the RGBA treatment; the remaining interleaved images pass through.
        sliced_samples = copy.copy(samples)
        sliced_samples["samples"] = sliced_samples["samples"][::frames]

        outputs = []
        for i in range(frames):
            imgs = images[i::frames]
            if i == 0:
                imgs = super().decode(
                    sliced_samples, imgs, sd_version, sub_batch_size
                )[0]
            outputs.append(imgs)
        # Pad with None so the node always yields MAX_FRAMES outputs.
        outputs += [None] * (self.MAX_FRAMES - frames)
        return tuple(outputs)


class LayerMethod(Enum):
    """How the transparency model is injected into the UNet."""

    ATTN = "Attention Injection"
    CONV = "Conv Injection"


class LayerType(Enum):
    """Layer role used as the conditioning type (foreground vs background)."""

    FG = "Foreground"
    BG = "Background"


@dataclass
class LayeredDiffusionBase:
    """Descriptor for one layer-diffusion model variant plus patch helpers."""

    # Weight file name under `layer_model_root`.
    model_file_name: str
    # Download URL used when the file is missing locally.
    model_url: str
    # SD architecture this variant targets.
    sd_version: StableDiffusionVersion
    # True for the SD1.5 attention-sharing variants.
    attn_sharing: bool = False
    injection_method: Optional[LayerMethod] = None
    cond_type: Optional[LayerType] = None
    # Number of output images per run.
    frames: int = 1

    @property
    def config_string(self) -> str:
        """Human-readable label shown in node config dropdowns."""
        injection_method = self.injection_method.value if self.injection_method else ""
        cond_type = self.cond_type.value if self.cond_type else ""
        attn_sharing = "attn_sharing" if self.attn_sharing else ""
        frames = f"Batch size ({self.frames}N)" if self.frames != 1 else ""
        return ", ".join(
            x
            for x in (
                self.sd_version.value,
                injection_method,
                cond_type,
                attn_sharing,
                frames,
            )
            if x
        )

    def apply_c_concat(self, cond, uncond, c_concat):
        """Set foreground/background concat condition."""

        def write_c_concat(cond):
            # Copy each (embedding, options) pair so the caller's conds are
            # not mutated, then attach the concat latent as a model cond.
            new_cond = []
            for t in cond:
                n = [t[0], t[1].copy()]
                if "model_conds" not in n[1]:
                    n[1]["model_conds"] = {}
                n[1]["model_conds"]["c_concat"] = CONDRegular(c_concat)
                new_cond.append(n)
            return new_cond

        return (write_c_concat(cond), write_c_concat(uncond))

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        weight: float,
    ):
        """Patch model"""
        model_path = load_file_from_url(
            url=self.model_url,
            model_dir=layer_model_root,
            file_name=self.model_file_name,
        )
        def pad_diff_weight(v):
            # Normalize patch tuples into ("diff", [tensor, {"pad_weight": True}])
            # form; pad_weight presumably tells ComfyUI to pad mismatched
            # weight shapes — verify against comfy's patch handling.
            if len(v) == 1:
                return ("diff", [v[0], {"pad_weight": True}])
            elif len(v) == 2 and v[0] == "diff":
                return ("diff", [v[1][0], {"pad_weight": True}])
            else:
                return v

        layer_lora_state_dict = load_layer_model_state_dict(model_path)
        layer_lora_patch_dict = {
            k: pad_diff_weight(v)
            for k, v in to_lora_patch_dict(layer_lora_state_dict).items()
        }
        # Clone so the base model remains unpatched for other workflows.
        work_model = model.clone()
        work_model.add_patches(layer_lora_patch_dict, weight)
        return (work_model,)

    def apply_layered_diffusion_attn_sharing(
        self,
        model: ModelPatcher,
        control_img: Optional[torch.TensorType] = None,
    ):
        """Patch model with attn sharing"""
        model_path = load_file_from_url(
            url=self.model_url,
            model_dir=layer_model_root,
            file_name=self.model_file_name,
        )
        layer_lora_state_dict = load_layer_model_state_dict(model_path)
        work_model = model.clone()
        patcher = AttentionSharingPatcher(
            work_model, self.frames, use_control=control_img is not None
        )
        # strict=True: the downloaded weights must exactly match the patcher.
        patcher.load_state_dict(layer_lora_state_dict, strict=True)
        if control_img is not None:
            patcher.set_control(control_img)
        return (work_model,)


def get_model_sd_version(model: ModelPatcher) -> StableDiffusionVersion:
    """Return the StableDiffusionVersion of `model`, based on its config type.

    Raises:
        Exception: if the model's config type is not SDXL, SD15, or SD20.
    """
    base_model: comfy.model_base.BaseModel = model.model
    config = base_model.model_config
    if isinstance(config, comfy.supported_models.SDXL):
        return StableDiffusionVersion.SDXL
    # SD15 and SD20 share a compatible architecture for our purposes.
    if isinstance(config, (comfy.supported_models.SD15, comfy.supported_models.SD20)):
        return StableDiffusionVersion.SD1x
    raise Exception(f"Unsupported SD Version: {type(config)}.")


class LayeredDiffusionFG:
    """Generate foreground with transparent background."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "config": ([c.config_string for c in s.MODELS],),
                "weight": (
                    "FLOAT",
                    {"default": 1.0, "min": -1, "max": 3, "step": 0.05},
                ),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    # Supported checkpoints; `config` selects one by its config_string.
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_xl_transparent_attn.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_transparent_attn.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            injection_method=LayerMethod.ATTN,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_xl_transparent_conv.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_transparent_conv.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            injection_method=LayerMethod.CONV,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_sd15_transparent_attn.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_transparent_attn.safetensors",
            sd_version=StableDiffusionVersion.SD1x,
            injection_method=LayerMethod.ATTN,
            attn_sharing=True,
        ),
    )

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        config: str,
        weight: float,
    ):
        """Patch `model` with the transparency checkpoint selected by `config`.

        Returns a 1-tuple holding the patched model clone.

        Raises:
            IndexError: if `config` matches no entry in MODELS.
            AssertionError: if the model's SD version differs from the config's.
        """
        ld_model = [m for m in self.MODELS if m.config_string == config][0]
        assert get_model_sd_version(model) == ld_model.sd_version
        if ld_model.attn_sharing:
            return ld_model.apply_layered_diffusion_attn_sharing(model)
        work_model = ld_model.apply_layered_diffusion(model, weight)
        # `comfy.model_management.unload_model` was removed in newer ComfyUI
        # releases; an unconditional call raises AttributeError there. Guard
        # so the node runs on both old and new runtimes.
        unload = getattr(comfy.model_management, "unload_model", None)
        if unload is not None:
            unload()
        return work_model


class LayeredDiffusionJoint:
    """Generate FG + BG + Blended in one inference batch. Batch size = 3N."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "config": ([c.config_string for c in s.MODELS],),
            },
            "optional": {
                "fg_cond": ("CONDITIONING",),
                "bg_cond": ("CONDITIONING",),
                "blended_cond": ("CONDITIONING",),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_sd15_joint.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_joint.safetensors",
            sd_version=StableDiffusionVersion.SD1x,
            attn_sharing=True,
            frames=3,
        ),
    )

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        config: str,
        fg_cond: Optional[List[List[torch.TensorType]]] = None,
        bg_cond: Optional[List[List[torch.TensorType]]] = None,
        blended_cond: Optional[List[List[torch.TensorType]]] = None,
    ):
        """Patch `model` for joint FG/BG/Blended generation.

        Stores per-frame conditioning overrides (FG, BG, Blended order) in
        transformer_options for the attention-sharing patches to consume.

        Raises:
            IndexError: if `config` matches no entry in MODELS.
            AssertionError: if the SD version mismatches or the selected
                config does not use attention sharing.
        """
        ld_model = [m for m in self.MODELS if m.config_string == config][0]
        assert get_model_sd_version(model) == ld_model.sd_version
        assert ld_model.attn_sharing
        work_model = ld_model.apply_layered_diffusion_attn_sharing(model)[0]
        work_model.model_options.setdefault("transformer_options", {})
        work_model.model_options["transformer_options"]["cond_overwrite"] = [
            c[0][0] if c is not None else None
            for c in (
                fg_cond,
                bg_cond,
                blended_cond,
            )
        ]
        # `comfy.model_management.unload_model` was removed in newer ComfyUI
        # releases; an unconditional call raises AttributeError there.
        unload = getattr(comfy.model_management, "unload_model", None)
        if unload is not None:
            unload()
        return (work_model,)


class LayeredDiffusionCond:
    """Generate foreground + background given background / foreground.
    - FG => Blended
    - BG => Blended
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "cond": ("CONDITIONING",),
                "uncond": ("CONDITIONING",),
                "latent": ("LATENT",),
                "config": ([c.config_string for c in s.MODELS],),
                "weight": (
                    "FLOAT",
                    {"default": 1.0, "min": -1, "max": 3, "step": 0.05},
                ),
            },
        }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING")
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_xl_fg2ble.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fg2ble.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.FG,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_xl_bg2ble.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bg2ble.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.BG,
        ),
    )

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        cond,
        uncond,
        latent,
        config: str,
        weight: float,
    ):
        """Patch `model` and attach the conditioning latent as c_concat.

        Returns (patched_model, cond, uncond) with the processed latent
        written into both conditionings' c_concat.

        Raises:
            IndexError: if `config` matches no entry in MODELS.
            AssertionError: if the model's SD version differs from the config's.
        """
        ld_model = [m for m in self.MODELS if m.config_string == config][0]
        assert get_model_sd_version(model) == ld_model.sd_version
        c_concat = model.model.latent_format.process_in(latent["samples"])
        # NOTE: a dead `comfy.model_management.unload_model()` call used to
        # sit after this return; it was unreachable and has been removed.
        return ld_model.apply_layered_diffusion(
            model, weight
        ) + ld_model.apply_c_concat(cond, uncond, c_concat)


class LayeredDiffusionCondJoint:
    """Generate fg/bg + blended given fg/bg.
    - FG => Blended + BG
    - BG => Blended + FG
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "image": ("IMAGE",),
                "config": ([c.config_string for c in s.MODELS],),
            },
            "optional": {
                "cond": ("CONDITIONING",),
                "blended_cond": ("CONDITIONING",),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_sd15_fg2bg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_fg2bg.safetensors",
            sd_version=StableDiffusionVersion.SD1x,
            attn_sharing=True,
            frames=2,
            cond_type=LayerType.FG,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_sd15_bg2fg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_bg2fg.safetensors",
            sd_version=StableDiffusionVersion.SD1x,
            attn_sharing=True,
            frames=2,
            cond_type=LayerType.BG,
        ),
    )

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        image,
        config: str,
        cond: Optional[List[List[torch.TensorType]]] = None,
        blended_cond: Optional[List[List[torch.TensorType]]] = None,
    ):
        """Patch `model` with attn-sharing and `image` as the control signal.

        Stores per-frame conditioning overrides (cond, blended order) in
        transformer_options for the attention-sharing patches to consume.

        Raises:
            IndexError: if `config` matches no entry in MODELS.
            AssertionError: if the SD version mismatches or the selected
                config does not use attention sharing.
        """
        ld_model = [m for m in self.MODELS if m.config_string == config][0]
        assert get_model_sd_version(model) == ld_model.sd_version
        assert ld_model.attn_sharing
        work_model = ld_model.apply_layered_diffusion_attn_sharing(
            model, control_img=image.movedim(-1, 1)
        )[0]
        work_model.model_options.setdefault("transformer_options", {})
        # Use a distinct loop variable: the original comprehension reused
        # `cond`, shadowing the parameter and obscuring intent.
        work_model.model_options["transformer_options"]["cond_overwrite"] = [
            c[0][0] if c is not None else None
            for c in (
                cond,
                blended_cond,
            )
        ]
        # `comfy.model_management.unload_model` was removed in newer ComfyUI
        # releases; an unconditional call raises AttributeError there.
        unload = getattr(comfy.model_management, "unload_model", None)
        if unload is not None:
            unload()
        return (work_model,)


class LayeredDiffusionDiff:
    """Extract FG/BG from blended image.
    - Blended + FG => BG
    - Blended + BG => FG
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "cond": ("CONDITIONING",),
                "uncond": ("CONDITIONING",),
                "blended_latent": ("LATENT",),
                "latent": ("LATENT",),
                "config": ([c.config_string for c in s.MODELS],),
                "weight": (
                    "FLOAT",
                    {"default": 1.0, "min": -1, "max": 3, "step": 0.05},
                ),
            },
        }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING")
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_xl_fgble2bg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fgble2bg.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.FG,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_xl_bgble2fg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bgble2fg.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.BG,
        ),
    )

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        cond,
        uncond,
        blended_latent,
        latent,
        config: str,
        weight: float,
    ):
        """Patch `model` and concat the known + blended latents as c_concat.

        Returns (patched_model, cond, uncond) with the concatenated latent
        written into both conditionings' c_concat.

        Raises:
            IndexError: if `config` matches no entry in MODELS.
            AssertionError: if the model's SD version differs from the config's.
        """
        ld_model = [m for m in self.MODELS if m.config_string == config][0]
        assert get_model_sd_version(model) == ld_model.sd_version
        c_concat = model.model.latent_format.process_in(
            torch.cat([latent["samples"], blended_latent["samples"]], dim=1)
        )
        # NOTE: a dead `comfy.model_management.unload_model()` call used to
        # sit after this return; it was unreachable and has been removed.
        return ld_model.apply_layered_diffusion(
            model, weight
        ) + ld_model.apply_c_concat(cond, uncond, c_concat)


# ComfyUI registration: node type identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "LayeredDiffusionApply": LayeredDiffusionFG,
    "LayeredDiffusionJointApply": LayeredDiffusionJoint,
    "LayeredDiffusionCondApply": LayeredDiffusionCond,
    "LayeredDiffusionCondJointApply": LayeredDiffusionCondJoint,
    "LayeredDiffusionDiffApply": LayeredDiffusionDiff,
    "LayeredDiffusionDecode": LayeredDiffusionDecode,
    "LayeredDiffusionDecodeRGBA": LayeredDiffusionDecodeRGBA,
    "LayeredDiffusionDecodeSplit": LayeredDiffusionDecodeSplit,
}

# ComfyUI registration: node type identifier -> human-readable display name
# (keys must match NODE_CLASS_MAPPINGS).
NODE_DISPLAY_NAME_MAPPINGS = {
    "LayeredDiffusionApply": "Layer Diffuse Apply",
    "LayeredDiffusionJointApply": "Layer Diffuse Joint Apply",
    "LayeredDiffusionCondApply": "Layer Diffuse Cond Apply",
    "LayeredDiffusionCondJointApply": "Layer Diffuse Cond Joint Apply",
    "LayeredDiffusionDiffApply": "Layer Diffuse Diff Apply",
    "LayeredDiffusionDecode": "Layer Diffuse Decode",
    "LayeredDiffusionDecodeRGBA": "Layer Diffuse Decode (RGBA)",
    "LayeredDiffusionDecodeSplit": "Layer Diffuse Decode (Split)",
}

jewelsonn avatar Mar 06 '25 20:03 jewelsonn