
How many steps are stage 2 and stage 3 trained for?

Open xilanhua12138 opened this issue 1 year ago • 14 comments

xilanhua12138 avatar Nov 22 '24 06:11 xilanhua12138

@zsxkib @guozinan126 @ToTheBeginning

xilanhua12138 avatar Nov 25 '24 12:11 xilanhua12138

For the XL model, stages 2 and 3 each run for no more than 30,000 iterations. As we mentioned in the paper, using only the ID loss can significantly improve similarity, especially on quantitative metrics, but it leads to a reward-hacking problem, so image quality degrades more and more severely the longer you train.

ToTheBeginning avatar Nov 27 '24 12:11 ToTheBeginning

I am also trying to train the model, but I have some confusion about the loss design in the second stage. According to the paper, the Lightning branch generates a prediction image from pure noise in 4 steps and computes the ID loss (cosine similarity) against the original image. This process is not influenced by the timestep of the diffusion-loss branch; it is affected only by the most recently updated PuLID encoder and pulid_ca modules. When the two losses are added together for backpropagation, does the ID loss indirectly affect the gradient updates? Right now I am quite confused, as the model is converging very slowly. @ToTheBeginning @zsxkib @guozinan126
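To be concrete, the ID loss I have in mind is just a cosine-distance penalty between face-recognition embeddings of the 4-step prediction and the reference image, roughly like this (a sketch; face_encoder and the pre-aligned face crops are placeholders for an ArcFace-style pipeline):

import torch
import torch.nn.functional as F

def id_loss(pred_image, ref_image, face_encoder):
    # Sketch: 1 - cosine similarity between face embeddings; assumes both inputs are
    # already aligned face crops in the encoder's expected value range.
    emb_pred = face_encoder(pred_image)      # [B, D], keeps gradients into the generator
    with torch.no_grad():
        emb_ref = face_encoder(ref_image)    # reference embedding, no gradient needed
    return (1.0 - F.cosine_similarity(emb_pred, emb_ref, dim=-1)).mean()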

LeonNerd avatar Feb 18 '25 09:02 LeonNerd

I'm not entirely sure, but one thing you can try is purposely overfitting on a small batch of data and checking whether the convergence behavior matches expectations. This is a common trick in deep learning: if the model can't overfit a tiny dataset, there's probably something wrong with the setup. You might want to test this and see if it gives any insight into why convergence is slow.
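Concretely, something like this (a rough sketch; train_step, train_dataloader, and optimizer stand in for your own setup):

# Reuse one small batch every step; if gradients flow correctly through the PuLID
# modules, the combined loss should drop sharply within a few hundred iterations.
tiny_batch = next(iter(train_dataloader))
for step in range(500):
    loss = train_step(model, tiny_batch)   # your forward pass + combined loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(step, loss.item())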

zsxkib avatar Feb 19 '25 10:02 zsxkib

Has anyone figured this out? Strangely, my stage-2 model does not converge. I am trying the method on SD3: in stage 2 I add Hyper-SD on top of the original setup and generate the image from pure noise in 4 steps. During generation the IDFormer and cross-attention modules carry gradients, the VAE-decoded tensor also keeps gradients, and the subsequent ArcFace feature extraction has gradients as well. The IDFormer and cross-attention modules are embedded in the Hyper-SD branch, so my ID loss updates them via gradients through that branch.

LeonNerd avatar Feb 28 '25 02:02 LeonNerd

I'd like to ask: during training, does the paper train only through SDXL_lightning, rather than training SDXL and SDXL_lightning jointly? I'm a bit unclear about the multi-branch training in the paper, so I'm asking people who have tried to reproduce it.

JamieCR1999 avatar Jun 04 '25 08:06 JamieCR1999

You can think of Lightning as a plug-in: it is only brought in when computing the ID loss, while the diffusion loss is still computed on the original model. I'm posting part of my code for reference.
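
# Imports the snippet below relies on (inferred from the calls; in recent diffusers
# versions calculate_shift and retrieve_timesteps live in the Flux pipeline module).
import random
from contextlib import nullcontext

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from diffusers import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
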

                model_pred = model(
                    x_t=x_t,
                    t=t,
                    clip_images=batch['clip_images'], 
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    text_ids=text_ids,
                    latent_image_ids=latent_image_ids,
                    guidance_vec=guidance_vec,
                    arch=arch, 
                    accelerator=accelerator, 
                    weight_dtype=weight_dtype,
                    args=args,
                    drop_image_embeds=batch["drop_image_embeds"],
                    face_embeddings=batch["face_embeddings"],
                    clip_images_ori=batch["clip_images_ori"],
                    ip_scale=ip_scale,
                    lora_scale=0.0,
                )
 
                loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean")

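                # Choose which of the few-step inference timesteps will run with gradients
                # enabled and store attention features (bp_ts of them); see align_t2i below.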
                seed = random.randint(0, 2**32 - 1)
                bp_ts = args.bp_ts
                num_inference_steps = 8
                sample_method = 'random'
                if sample_method == 'random':
                    bp_idx = random.sample(range(num_inference_steps - bp_ts), bp_ts)
                else:  # 'last'
                    bp_idx = [num_inference_steps - bp_ts]
                # Sample time steps for inference
                with torch.inference_mode():
                    KQQs_t2i_temp, Qs_t2i_temp, image_t2i = align_t2i(        
                        args,
                        model, 
                        batch,
                        x_t,
                        prompt_embeds,
                        pooled_prompt_embeds,
                        text_ids,
                        arch,
                        arch_body,
                        body_vae_embeds,
                        subject_embeddings,
                        body_clip_images,
                        accelerator,
                        noise_scheduler,
                        taef1,
                        weight_dtype,
                        ip_scale=0.0,
                        seed=seed,
                        out_image=True,
                        bp_ts=bp_ts,
                        bp_idx=bp_idx,
                        lora_scale=1.0,
                        num_inference_steps=num_inference_steps,
                        image_height=args.resolution[0],
                        image_width=args.resolution[1],
                    )
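                # The pure T2I branch (ip_scale=0.0) ran under inference_mode above; detach and
                # clone its stored features so they act as fixed targets for the alignment losses.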
                KQQs_t2i = [{k: v.detach().clone() for k, v in kqq.items()} for kqq in KQQs_t2i_temp]
                Qs_t2i = [{k: v.detach().clone() for k, v in q.items()} for q in Qs_t2i_temp]
                KQQs_ipa, Qs_ipa, image_ipa = align_t2i(        
                    args,
                    model, 
                    batch,
                    x_t,
                    prompt_embeds,
                    pooled_prompt_embeds,
                    text_ids,
                    arch,
                    arch_body,
                    body_vae_embeds,
                    subject_embeddings,
                    body_clip_images,
                    accelerator,
                    noise_scheduler,
                    taef1,
                    weight_dtype,
                    ip_scale=1.0,
                    seed=seed,
                    out_image=True,
                    bp_ts=bp_ts,
                    bp_idx=bp_idx,
                    lora_scale=1.0,
                    num_inference_steps=num_inference_steps,
                    image_height=args.resolution[0],
                    image_width=args.resolution[1],
                )

                loss_align_semantic = 0.0
                loss_align_layout = 0.0
                for step in range(bp_ts):
                    loss_align_semantic += calcu_loss_align(KQQs_ipa[step], KQQs_t2i[step])
                    loss_align_layout += calcu_loss_align(Qs_t2i[step], Qs_ipa[step])
                loss_align_semantic = loss_align_semantic / bp_ts
                loss_align_layout = loss_align_layout / bp_ts

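                # ID loss between the image generated by the ID-conditioned branch and the
                # ground-truth pixels.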
                loss_dict = criterion_id.forward(image_ipa, pixel_values)
                loss_face = loss_dict['loss_id']

                loss = loss + args.id_loss_weight * loss_face + args.look_align_layout_weight * loss_align_layout + args.look_align_semantic_weight * loss_align_semantic

def align_t2i(
        args,
        model, 
        batch,
        x_t,
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
        arch,
        arch_body,
        body_vae_embeds,
        subject_embeddings,
        body_clip_images,
        accelerator,
        noise_scheduler,
        taef1,
        weight_dtype,
        ip_scale,
        seed,
        num_inference_steps=8,
        bp_ts=1,
        out_image=True,
        store_attention_maps='align',
        lora_scale=1.0,
        bp_idx=None,
        image_height=512,
        image_width=512,
        ):
    # Sample time steps for inference
    generator = torch.Generator(device=accelerator.device).manual_seed(seed)
    bs = x_t.shape[0]
    # height = 2 * math.ceil(batch['images'].shape[2] // 16)
    # width = 2 * math.ceil(batch['images'].shape[3] // 16)
    height = 2 * image_height // 16
    width = 2 * image_width // 16
    latents = torch.randn(
        bs,
        16,
        height,
        width,
        device=accelerator.device,
        dtype=x_t.dtype,
        generator=generator,
    )
    latents = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    latent_image_ids = FluxPipeline._prepare_latent_image_ids(bs, height // 2, width // 2, accelerator.device, x_t.dtype)
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        noise_scheduler.config.base_image_seq_len,
        noise_scheduler.config.max_image_seq_len,
        noise_scheduler.config.base_shift,
        noise_scheduler.config.max_shift,
    )
    timesteps_inference, num_inference_steps = retrieve_timesteps(
        noise_scheduler,
        num_inference_steps,
        accelerator.device,
        None,
        sigmas,
        mu=mu,
    )
    guidance = torch.full([1], 4, device=accelerator.device, dtype=torch.float32)
    guidance = guidance.expand(x_t.shape[0])
    
    KQQs = []
    Qs = []
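    # Steps whose index is in bp_idx run with gradients enabled and store attention features
    # for the alignment losses; all other denoising steps run under inference_mode.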
    for i, t in enumerate(timesteps_inference):
        if i in bp_idx:
            context = nullcontext()
            store_attention_maps = 'align'
        else:
            context = torch.inference_mode()
            store_attention_maps = None

        with context:
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
            noise_pred = model(
                x_t=latents,
                t=timestep/1000,
                clip_images=batch['clip_images'], 
                body_clip_images=body_clip_images, 
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                text_ids=text_ids,
                latent_image_ids=latent_image_ids,
                guidance_vec=guidance,
                arch=arch, 
                arch_body=arch_body,
                vae_embeds=body_vae_embeds,
                accelerator=accelerator, 
                weight_dtype=weight_dtype,
                args=args,
                drop_image_embeds=batch["drop_image_embeds"],
                face_embeddings=batch["face_embeddings"],
                clip_images_ori=batch["clip_images_ori"],
                subject_embedding=subject_embeddings,
                ip_scale=ip_scale,
                store_attention_maps=store_attention_maps,
                lora_scale=lora_scale,
            )
            latents = noise_scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            
            if i in bp_idx:
                current_KQQs = {k: v for k, v in model.KQQs.items()}
                current_Qs = {k: v for k, v in model.Qs.items()}
                KQQs.append(current_KQQs)
                Qs.append(current_Qs)

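    # Decode with the tiny autoencoder (taef1) so the decode feeding the ID loss stays
    # cheap and differentiable.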
    if out_image:
        latents = rearrange(
            latents,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=image_height // 16,
            w=image_width // 16,
            ph=2,
            pw=2,
        )
        latents = (latents / taef1.config.scaling_factor) + taef1.config.shift_factor
        image = taef1.decode(latents.to(weight_dtype), return_dict=False)[0]
    else:
        image = None
    return KQQs, Qs, image
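
calcu_loss_align is not included above; a minimal version that matches how it is called (an averaged MSE over the stored feature dictionaries: attention maps for the semantic term, query features for the layout term) would look roughly like this, reusing F = torch.nn.functional from the imports:

def calcu_loss_align(feats_a, feats_b):
    # Sketch only: average an MSE over matching entries of the two feature dicts.
    loss = 0.0
    for name, a in feats_a.items():
        loss = loss + F.mse_loss(a.float(), feats_b[name].float())
    return loss / max(len(feats_a), 1)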

xilanhua12138 avatar Jun 04 '25 09:06 xilanhua12138

Thanks! I'll go study it!

JamieCR1999 avatar Jun 04 '25 09:06 JamieCR1999

Hello, I still have some questions about the training stages. My understanding is that stage 1 optimizes the ID projection model, and then stages 2 and 3 no longer use SDXL but switch to the Lightning model to train the K and V projections and to further train the ID projection model. Is that understanding correct?

JamieCR1999 avatar Jun 11 '25 14:06 JamieCR1999

In other words, only stage 1 of training uses SDXL, and in all the remaining stages it is replaced by the Lightning model?

JamieCR1999 avatar Jun 11 '25 14:06 JamieCR1999

Read the paper carefully; the loss is

[Image: the combined training-loss equation from the paper]

In my own implementation, the first term uses the original model (Flux in my case), while the latter two terms use Flux with the acceleration LoRA attached. The reason this is trainable is that the LoRA is just an extra component on top of Flux, so gradients still propagate normally to the two linear layers of the adapter being trained.
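
In other words, matching the weights in the snippet I posted above (args.id_loss_weight, args.look_align_layout_weight, args.look_align_semantic_weight), the overall objective is roughly

$$
\mathcal{L} \;=\; \mathcal{L}_{\text{diff}} \;+\; \lambda_{\text{id}}\,\mathcal{L}_{\text{id}} \;+\; \lambda_{\text{layout}}\,\mathcal{L}_{\text{align-layout}} \;+\; \lambda_{\text{sem}}\,\mathcal{L}_{\text{align-sem}}
$$

where only the first (diffusion) term is computed with the original model, and the ID loss and the two alignment losses are computed through the LoRA-accelerated branch.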

xilanhua12138 avatar Jun 12 '25 14:06 xilanhua12138

Hi, could I ask what should be used as K and Q when computing the alignment loss when reproducing the training on the Flux model? The textual features and "UNet features" from the paper don't map very clearly onto Flux. @xilanhua12138

EngDoge avatar Oct 11 '25 08:10 EngDoge

It's actually quite clear: Flux also separates the image modality from the language modality, so Q should come from the image modality and K from the language modality.

xilanhua12138 avatar Oct 11 '25 08:10 xilanhua12138

Judging from Flux's inference code, there is a distinction between double-stream and single-stream blocks. My own guess is to use the txt produced by img, txt = block(img=img, txt=txt, vec=vec, pe=pe) as K, and to use the img (double stream) / real_img (single stream) obtained from img/real_img = img + id_weight * self.pulid_ca[ca_idx](id, img) as Q. I'm not sure whether this is feasible. @xilanhua12138
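
In code form, what I have in mind is roughly this (only a sketch of my guess; the loop index i, the dict keys, and the softmax normalization are my own additions on top of the lines quoted above):

# Inside the double-stream block loop of the Flux forward pass (not verified):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = img + id_weight * self.pulid_ca[ca_idx](id, img)   # PuLID injection
if store_attention_maps is not None:
    # txt tokens = language modality (K), ID-injected img tokens = image modality (Q)
    self.Qs[f"double_{i}"] = img
    self.KQQs[f"double_{i}"] = torch.softmax(
        img @ txt.transpose(-1, -2) / img.shape[-1] ** 0.5, dim=-1
    )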

EngDoge avatar Oct 11 '25 08:10 EngDoge