WanX 14B I2V "Cannot load LoRA"
After running LoRA training for WanX 14B I2V, I get checkpoints like the following:
Then I run zero_to_fp32.py to consolidate them into a pytorch_model.bin file. But when I try to load it to test the LoRA, it fails to match any LoRA type: lora.match(model, state_dict) returns None.
How exactly should I test the LoRA results?
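For reference, a minimal sketch of the failing check (the checkpoint path and the dit variable are placeholders):

# Minimal repro sketch; "output/pytorch_model.bin" and `dit` (the loaded
# DiT module) are placeholders for illustration.
from diffsynth import load_state_dict
from diffsynth.models.lora import GeneralLoRAFromPeft

state_dict = load_state_dict("output/pytorch_model.bin")
lora = GeneralLoRAFromPeft()
print(lora.match(dit, state_dict))  # prints None, i.e. "Cannot load LoRA"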
Using this config without DeepSpeed will generate only one *.ckpt in the output folder:
python examples/wanvideo/train_wan_t2v.py \
  --task train \
  --train_architecture lora \
  --dataset_path data/test \
  --output_path ./output \
  --dit_path "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors" \
  --steps_per_epoch 500 \
  --max_epochs 10 \
  --learning_rate 1e-4 \
  --lora_rank 16 \
  --lora_alpha 16 \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --use_gradient_checkpointing \
  --use_gradient_checkpointing_offload
When loading the LoRA, it prints something like this: Adding LoRA to wan_video_dit (['models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors', 'models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors', 'models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors', 'models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors', 'models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors', 'models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors']).
But the models are stored on an HDD, so loading takes a long time, and the program was killed by the system the first time I loaded the LoRA.
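For context, this is roughly how the LoRA is loaded for inference; a sketch assuming the usual DiffSynth-Studio pattern (file paths are examples), where ModelManager.load_lora is what prints the "Adding LoRA to wan_video_dit" line above:

from diffsynth import ModelManager, WanVideoPipeline

# Rough sketch of LoRA loading for inference; checkpoint paths are examples.
model_manager = ModelManager(device="cpu")
model_manager.load_models([
    [f"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-{i:05d}-of-00006.safetensors" for i in range(1, 7)],
    "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
    "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
])
model_manager.load_lora("output/epoch=9.ckpt", lora_alpha=1.0)
pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")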
@qidai2000 This is a bug in DeepSpeed: some extra parameters are stored in the checkpoint files. You can extract the LoRA state dict using the following code.
from diffsynth import load_state_dict
import torch

# Keep only the LoRA up/down weights from the consolidated checkpoint.
state_dict = load_state_dict("xxx.pth")
state_dict = {i: state_dict[i] for i in state_dict if i.endswith(".lora_A.default.weight") or i.endswith(".lora_B.default.weight")}
torch.save(state_dict, "yyy.pth")
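The filtered yyy.pth should then load like any other LoRA file, e.g. via ModelManager.load_lora.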
Thank you for your reply, but it doesn't work :( I have checked the state_dict keys; they all end with the LoRA weight suffixes, but it still cannot match any LoRA type.
The problem is the same for full fine-tuning: the model type of the fine-tuned checkpoints fails to be detected, and I get "No models are loaded".
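For reference, a rough sketch of how I checked the keys (the checkpoint path is a placeholder):

from diffsynth import load_state_dict

# Print every key of the consolidated checkpoint; they all end with
# .lora_A.default.weight / .lora_B.default.weight, yet match() returns None.
state_dict = load_state_dict("pytorch_model.bin")
for key in state_dict:
    print(key)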
Has anyone figured this out? I can't convert to a Comfy-friendly format; the LoRA weights aren't loaded properly.
My fix: change models.lora.GeneralLoRAFromPeft so that it strips the "pipe.dit." prefix that the training checkpoint adds to every key:
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
    state_dict_model = model.state_dict()
    device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
    lora_name_dict = self.get_name_dict(state_dict_lora)
    for name in lora_name_dict:
        weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
        weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
        if len(weight_up.shape) == 4:
            # Conv weights: squeeze to 2D, multiply, then restore the 4D shape.
            weight_up = weight_up.squeeze(3).squeeze(2)
            weight_down = weight_down.squeeze(3).squeeze(2)
            weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            weight_lora = alpha * torch.mm(weight_up, weight_down)
        # Strip the "pipe.dit." prefix added by the training checkpoint so the
        # key matches the base model's state dict.
        weight_model = state_dict_model[name.replace("pipe.dit.", "")].to(device=computation_device, dtype=computation_dtype)
        weight_patched = weight_model + weight_lora
        state_dict_model[name.replace("pipe.dit.", "")] = weight_patched.to(device=device, dtype=dtype)
    print(f"    {len(lora_name_dict)} tensors are updated.")
    model.load_state_dict(state_dict_model)

def match(self, model: torch.nn.Module, state_dict_lora):
    lora_name_dict = self.get_name_dict(state_dict_lora)
    model_name_dict = {name: None for name, _ in model.named_parameters()}
    # Strip the "pipe.dit." prefix before checking that every LoRA target
    # exists in the model; return None only if some key fails to match.
    matched_num = sum([i.replace("pipe.dit.", "") in model_name_dict for i in lora_name_dict])
    if matched_num == len(lora_name_dict):
        return "", ""
    else:
        return None
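An equivalent workaround, sketched under the same assumption (keys prefixed with "pipe.dit."), is to strip the prefix from the extracted LoRA file once, so the stock matcher recognizes it without patching the loader:

import torch
from diffsynth import load_state_dict

# Hypothetical one-off fix: remove the "pipe.dit." training prefix from the
# extracted LoRA keys so the unpatched GeneralLoRAFromPeft.match() succeeds.
state_dict = load_state_dict("yyy.pth")
state_dict = {key.replace("pipe.dit.", "", 1): value for key, value in state_dict.items()}
torch.save(state_dict, "yyy_stripped.pth")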