How many steps are stage 2 and stage 3 trained for?
@zsxkib @guozinan126 @ToTheBeginning
For the XL model, neither stage 2 nor stage 3 exceeds 30,000 iterations. As we mention in the paper, using only the ID loss can significantly improve similarity, especially on quantitative metrics, but it leads to a reward-hacking problem: the longer you train, the more severely image quality degrades.
I am also trying to train the model, but I have some confusion about the loss design in the second stage. According to the paper, the Lightning branch generates a prediction image from pure noise in 4 steps and computes the ID loss (cosine similarity) against the original image. This process is not influenced by the diffusion-loss branch's timestep; it is affected only by the most recently updated pulid encoder and pulid_ca modules. By adding these two losses together for backpropagation, does the ID loss indirectly affect the gradient updates? Currently I am quite confused, as the model is converging very slowly. @ToTheBeginning @zsxkib @guozinan126
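For what it's worth, this is how I picture that stage-2 setup as a minimal sketch. All names (`unet`, `lightning_unet`, `scheduler`, `vae`, `face_encoder`, ...) are placeholders of my own, not the repo's actual API; it only illustrates how the ID loss can reach the shared ID modules through the 4-step Lightning branch while the diffusion loss is computed as usual:

```python
import torch
import torch.nn.functional as F

def stage2_loss(unet, lightning_unet, scheduler, vae, face_encoder,
                noisy_latents, timesteps, target, cond, ref_face_embed,
                id_loss_weight=1.0, num_fast_steps=4):
    # Branch 1: the usual diffusion loss at a randomly sampled timestep.
    pred = unet(noisy_latents, timesteps, cond)
    diffusion_loss = F.mse_loss(pred.float(), target.float())

    # Branch 2: the Lightning branch samples from pure noise in a few steps
    # with gradients enabled, so the ID modules inside lightning_unet are reached.
    latents = torch.randn_like(noisy_latents)
    for t in scheduler.timesteps[:num_fast_steps]:
        noise_pred = lightning_unet(latents, t, cond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    image = vae.decode(latents)            # keep gradients through the decode
    id_embed = face_encoder(image)         # e.g. an ArcFace-style face embedding
    id_loss = 1.0 - F.cosine_similarity(id_embed, ref_face_embed, dim=-1).mean()

    # The two losses are simply summed; the ID loss reaches the shared ID modules
    # directly through branch 2, independently of branch 1's timestep.
    return diffusion_loss + id_loss_weight * id_loss
```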
I'm not entirely sure, but one thing you can try is purposefully overfitting on a small batch of data and checking whether the convergence behavior matches expectations. This is a common trick in deep learning: if the model can't overfit a tiny dataset, there's probably something wrong with the setup. You might want to test this and see whether it gives any insight into why convergence is slow.
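A minimal sketch of that sanity check, with placeholder names (`compute_total_loss` stands for whatever builds your combined diffusion + ID loss for a batch):

```python
import torch

def overfit_one_batch(model, optimizer, batch, compute_total_loss,
                      steps=500, log_every=50):
    # Repeatedly train on the exact same small batch; if the combined loss does not
    # collapse toward zero, something in the setup (gradient path, data, loss wiring)
    # is likely broken rather than just the hyperparameters.
    model.train()
    for step in range(steps):
        optimizer.zero_grad(set_to_none=True)
        loss = compute_total_loss(model, batch)
        loss.backward()
        optimizer.step()
        if step % log_every == 0:
            print(f"step {step}: loss {loss.item():.4f}")
```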
Have you solved this problem? Strangely, my stage-2 model does not converge. I am trying the method on SD3: in stage 2 I added HyperSD on top of the original setup, generating the image from pure noise in 4 steps. During image generation the IDFormer and cross-attention modules both carry gradients, the VAE-decoded tensor also keeps gradients, and the subsequent ArcFace feature extraction has gradients as well. The IDFormer and cross-attention modules are embedded in the HyperSD branch, so my ID loss updates them through gradients flowing along that branch.
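One quick check worth running in this situation (a sketch; the keywords `idformer` / `pulid_ca` are just examples of how your parameters might be named): after `loss.backward()`, confirm that the ID modules actually receive non-zero gradients, so the ID loss isn't being silently cut off by a `no_grad`/`inference_mode` context or a detach before the VAE decode:

```python
import torch

def report_grad_norms(model: torch.nn.Module, keywords=("idformer", "pulid_ca")):
    # Print the gradient norm of every trainable parameter whose name matches one of
    # the keywords; all-zero (or None) grads mean the ID-loss gradient path is broken.
    for name, p in model.named_parameters():
        if p.requires_grad and any(k in name.lower() for k in keywords):
            g = 0.0 if p.grad is None else p.grad.norm().item()
            print(f"{name}: grad_norm={g:.6f}")
```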
May I ask whether this paper trains only through SDXL-Lightning, rather than jointly training SDXL and SDXL-Lightning? I am a bit confused by the multi-branch training in the paper, so I wanted to ask people who have tried to reproduce it.
You can think of Lightning as a plug-in: it is only brought in when computing the ID loss, while the diffusion loss is still computed on the original model. I'll post part of my code for reference.
```python
# NOTE: fragment from my training step; it assumes the usual imports are in scope
# (random, numpy as np, torch, torch.nn.functional as F, rearrange from einops,
# nullcontext from contextlib, FluxPipeline plus the calculate_shift /
# retrieve_timesteps helpers from diffusers' flux pipeline).

# Branch 1: plain diffusion loss on the original model (acceleration LoRA disabled).
model_pred = model(
    x_t=x_t,
    t=t,
    clip_images=batch['clip_images'],
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    text_ids=text_ids,
    latent_image_ids=latent_image_ids,
    guidance_vec=guidance_vec,
    arch=arch,
    accelerator=accelerator,
    weight_dtype=weight_dtype,
    args=args,
    drop_image_embeds=batch["drop_image_embeds"],
    face_embeddings=batch["face_embeddings"],
    clip_images_ori=batch["clip_images_ori"],
    ip_scale=ip_scale,
    lora_scale=0.0,
)
loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean")

# Branch 2: few-step sampling with the acceleration LoRA; bp_idx selects which of the
# num_inference_steps steps keep gradients for backprop.
seed = random.randint(0, 2**32 - 1)
bp_ts = args.bp_ts
num_inference_steps = 8
sample_method = 'random'
if sample_method == 'random':
    bp_idx = random.sample(range(num_inference_steps - bp_ts), bp_ts)
else:  # 'last'
    bp_idx = [num_inference_steps - bp_ts]

# Reference T2I pass (ip_scale=0.0, no ID condition), fully under inference_mode.
with torch.inference_mode():
    KQQs_t2i_temp, Qs_t2i_temp, image_t2i = align_t2i(
        args,
        model,
        batch,
        x_t,
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
        arch,
        arch_body,
        body_vae_embeds,
        subject_embeddings,
        body_clip_images,
        accelerator,
        noise_scheduler,
        taef1,
        weight_dtype,
        ip_scale=0.0,
        seed=seed,
        out_image=True,
        bp_ts=bp_ts,
        bp_idx=bp_idx,
        lora_scale=1.0,
        num_inference_steps=num_inference_steps,
        image_height=args.resolution[0],
        image_width=args.resolution[1],
    )
KQQs_t2i = [{k: v.detach().clone() for k, v in kqq.items()} for kqq in KQQs_t2i_temp]
Qs_t2i = [{k: v.detach().clone() for k, v in q.items()} for q in Qs_t2i_temp]

# ID-conditioned pass (ip_scale=1.0); gradients flow at the bp_idx steps inside align_t2i.
KQQs_ipa, Qs_ipa, image_ipa = align_t2i(
    args,
    model,
    batch,
    x_t,
    prompt_embeds,
    pooled_prompt_embeds,
    text_ids,
    arch,
    arch_body,
    body_vae_embeds,
    subject_embeddings,
    body_clip_images,
    accelerator,
    noise_scheduler,
    taef1,
    weight_dtype,
    ip_scale=1.0,
    seed=seed,
    out_image=True,
    bp_ts=bp_ts,
    bp_idx=bp_idx,
    lora_scale=1.0,
    num_inference_steps=num_inference_steps,
    image_height=args.resolution[0],
    image_width=args.resolution[1],
)

# Alignment losses between the two passes, averaged over the backprop steps.
loss_align_semantic = 0.0
loss_align_layout = 0.0
for step in range(bp_ts):
    loss_align_semantic += calcu_loss_align(KQQs_ipa[step], KQQs_t2i[step])
    loss_align_layout += calcu_loss_align(Qs_t2i[step], Qs_ipa[step])
loss_align_semantic = loss_align_semantic / bp_ts
loss_align_layout = loss_align_layout / bp_ts

# ID loss on the decoded ID-conditioned image, then combine everything.
loss_dict = criterion_id.forward(image_ipa, pixel_values)
loss_face = loss_dict['loss_id']
loss = (
    loss
    + args.id_loss_weight * loss_face
    + args.look_align_layout_weight * loss_align_layout
    + args.look_align_semantic_weight * loss_align_semantic
)
```
```python
def align_t2i(
    args,
    model,
    batch,
    x_t,
    prompt_embeds,
    pooled_prompt_embeds,
    text_ids,
    arch,
    arch_body,
    body_vae_embeds,
    subject_embeddings,
    body_clip_images,
    accelerator,
    noise_scheduler,
    taef1,
    weight_dtype,
    ip_scale,
    seed,
    num_inference_steps=8,
    bp_ts=1,
    out_image=True,
    store_attention_maps='align',
    lora_scale=1.0,
    bp_idx=None,
    image_height=512,
    image_width=512,
):
    # Sample time steps for inference
    generator = torch.Generator(device=accelerator.device).manual_seed(seed)
    bs = x_t.shape[0]
    # height = 2 * math.ceil(batch['images'].shape[2] // 16)
    # width = 2 * math.ceil(batch['images'].shape[3] // 16)
    height = 2 * image_height // 16
    width = 2 * image_width // 16
    latents = torch.randn(
        bs,
        16,
        height,
        width,
        device=accelerator.device,
        dtype=x_t.dtype,
        generator=generator,
    )
    latents = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    latent_image_ids = FluxPipeline._prepare_latent_image_ids(
        bs, height // 2, width // 2, accelerator.device, x_t.dtype
    )
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        noise_scheduler.config.base_image_seq_len,
        noise_scheduler.config.max_image_seq_len,
        noise_scheduler.config.base_shift,
        noise_scheduler.config.max_shift,
    )
    timesteps_inference, num_inference_steps = retrieve_timesteps(
        noise_scheduler,
        num_inference_steps,
        accelerator.device,
        None,
        sigmas,
        mu=mu,
    )
    guidance = torch.full([1], 4, device=accelerator.device, dtype=torch.float32)
    guidance = guidance.expand(x_t.shape[0])

    KQQs = []
    Qs = []
    for i, t in enumerate(timesteps_inference):
        # Only the steps in bp_idx run with gradients and store attention features;
        # all other steps run under inference_mode to save memory.
        if i in bp_idx:
            context = nullcontext()
            store_attention_maps = 'align'
        else:
            context = torch.inference_mode()
            store_attention_maps = None
        with context:
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
            noise_pred = model(
                x_t=latents,
                t=timestep / 1000,
                clip_images=batch['clip_images'],
                body_clip_images=body_clip_images,
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                text_ids=text_ids,
                latent_image_ids=latent_image_ids,
                guidance_vec=guidance,
                arch=arch,
                arch_body=arch_body,
                vae_embeds=body_vae_embeds,
                accelerator=accelerator,
                weight_dtype=weight_dtype,
                args=args,
                drop_image_embeds=batch["drop_image_embeds"],
                face_embeddings=batch["face_embeddings"],
                clip_images_ori=batch["clip_images_ori"],
                subject_embedding=subject_embeddings,
                ip_scale=ip_scale,
                store_attention_maps=store_attention_maps,
                lora_scale=lora_scale,
            )
            latents = noise_scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            if i in bp_idx:
                current_KQQs = {k: v for k, v in model.KQQs.items()}
                current_Qs = {k: v for k, v in model.Qs.items()}
                KQQs.append(current_KQQs)
                Qs.append(current_Qs)

    if out_image:
        latents = rearrange(
            latents,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=image_height // 16,
            w=image_width // 16,
            ph=2,
            pw=2,
        )
        latents = (latents / taef1.config.scaling_factor) + taef1.config.shift_factor
        image = taef1.decode(latents.to(weight_dtype), return_dict=False)[0]
    else:
        image = None
    return KQQs, Qs, image
```
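(`calcu_loss_align` and `criterion_id` are not included in the snippet above. If you just want something runnable, a minimal stand-in for `calcu_loss_align` that is consistent with how it is called, not necessarily what anyone's real implementation looks like, is a per-layer MSE between the two captured feature dicts:)

```python
import torch.nn.functional as F

def calcu_loss_align(feats_a, feats_b):
    # feats_a / feats_b: dicts mapping layer names to tensors (the stored KQQs or Qs)
    # captured at the same denoising step in the ID-conditioned and plain T2I passes.
    losses = [F.mse_loss(feats_a[k].float(), feats_b[k].float()) for k in feats_a]
    return sum(losses) / max(len(losses), 1)
```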
Thanks! I'll go study it!
Hello, I still have a few questions about the training stages. My understanding is that stage 1 is meant to optimize the ID projection model? Then in stages 2 and 3, SDXL is no longer used; instead the Lightning model is used to train the K and V projections and to further train the ID projection model? I'm not sure whether my understanding is correct.
In other words, only the first stage of training uses SDXL, and in all the remaining stages it is replaced by the Lightning model?
Read the paper more carefully. The loss has three terms: the diffusion loss, the alignment loss, and the ID loss.
In my own implementation, the first term is computed with the original model (flux in my case), while the latter two are computed with flux plus the acceleration LoRA. The reason this is trainable is that, to flux, the LoRA is just an extra component, so gradients still propagate normally to the two linear layers of the adapter being trained.
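Written out the way the code earlier in this thread combines the terms (my own reconstruction; the paper's exact notation and weights may differ), the objective is roughly:

```latex
\mathcal{L}
  = \mathcal{L}_{\mathrm{diff}}
  + \lambda_{\mathrm{sem}}\,\mathcal{L}_{\mathrm{align\text{-}sem}}
  + \lambda_{\mathrm{layout}}\,\mathcal{L}_{\mathrm{align\text{-}layout}}
  + \lambda_{\mathrm{id}}\,\mathcal{L}_{\mathrm{id}}
```

where the diffusion term runs on the base model (lora_scale=0.0), while the alignment and ID terms run on the accelerated LoRA/Lightning branch (lora_scale=1.0); the weights correspond to args.id_loss_weight, args.look_align_semantic_weight, and args.look_align_layout_weight in the code above.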
Hi, may I ask what should be used as K and Q when computing the alignment loss while reproducing the training on the flux model? In flux, the counterparts of the paper's textual features and "UNet features" don't seem very clear. @xilanhua12138
It's actually quite clear: flux also separates the image modality from the language modality, so Q should come from the image modality and K from the language modality.
Looking at flux's inference code, there are Double and Single Stream Blocks. My own guess is to use the txt returned by `img, txt = block(img=img, txt=txt, vec=vec, pe=pe)` as K, and to use the img (double stream) / real_img (single stream) produced by `img = img + id_weight * self.pulid_ca[ca_idx](id, img)` as Q. I'm not sure whether that is feasible. @xilanhua12138
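The kind of capture I have in mind, as a rough sketch only (in the real pulid-flux forward the `pulid_ca` injection happens only every few blocks, and attribute names may differ): store `txt` after each double-stream block as K and `img` right after the PuLID injection as Q, keyed by block index, then feed those dicts to the alignment loss.

```python
def forward_double_blocks(double_blocks, pulid_ca, img, txt, vec, pe,
                          id_embed, id_weight, captured_q=None, captured_k=None):
    # Run the flux double-stream blocks with PuLID-style injection on the image stream,
    # optionally capturing Q (image tokens) and K (text tokens) for the alignment loss.
    for i, (block, ca) in enumerate(zip(double_blocks, pulid_ca)):
        img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
        img = img + id_weight * ca(id_embed, img)      # ID injection on the image stream
        if captured_q is not None:
            captured_q[f"double_{i}"] = img            # image modality -> Q
            captured_k[f"double_{i}"] = txt            # language modality -> K
    return img, txt
```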