
WanX 14B I2V "Cannot load LoRA"

qidai2000 opened this issue on Mar 26 '25 · 6 comments

After I ran LoRA training for WanX 14B I2V, I got checkpoints like the ones shown in the attached screenshot.

Then I ran zero_to_fp32.py to get a pytorch_model.bin file. But when I try to load it to test the LoRA, it fails to match any LoRA type, and lora.match(model, state_dict) returns None.
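
For reference, the keys in the converted checkpoint can be inspected like this (a small diagnostic sketch; pytorch_model.bin is whatever file zero_to_fp32.py produced):

import torch

# Peek at the keys the LoRA matcher is actually being given.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
for key in list(state_dict)[:20]:
    print(key, tuple(state_dict[key].shape))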

How exactly should I test the LoRA results?

qidai2000 · Mar 26 '25

Using this config without DeepSpeed will generate only one *.ckpt file in the output folder:

python examples/wanvideo/train_wan_t2v.py \
  --task train \
  --train_architecture lora \
  --dataset_path data/test \
  --output_path ./output \
  --dit_path "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors" \
  --steps_per_epoch 500 \
  --max_epochs 10 \
  --learning_rate 1e-4 \
  --lora_rank 16 \
  --lora_alpha 16 \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --use_gradient_checkpointing \
  --use_gradient_checkpointing_offload

When the LoRA is loaded, it prints something like this: Adding LoRA to wan_video_dit (['models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors', 'models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors', 'models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors', 'models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors', 'models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors', 'models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors']).

But since the models are saved on an HDD, loading takes a long time, and the program was killed by the system the first time I tried to load the LoRA.
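
For reference, the loading step that produces that message looks roughly like this (a minimal sketch based on the repo's Wan video examples; the LoRA checkpoint path and lora_alpha are placeholders to adjust):

import torch
from diffsynth import ModelManager, WanVideoPipeline

# Load the sharded Wan2.1 14B DiT weights (the text encoder and VAE files
# for this model go in the same list).
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
    [
        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
    ],
])

# The single .ckpt produced by training (placeholder path).
model_manager.load_lora("output/your_lora_checkpoint.ckpt", lora_alpha=1.0)

pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")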

meareabc · Mar 26 '25

@qidai2000 This is a bug in DeepSpeed: some extra parameters are stored in the checkpoint files. You can extract the LoRA state dict with the following code.

from diffsynth import load_state_dict
import torch

# Load the converted checkpoint and keep only the LoRA weights.
state_dict = load_state_dict("xxx.pth")
state_dict = {i: state_dict[i] for i in state_dict if i.endswith(".lora_A.default.weight") or i.endswith(".lora_B.default.weight")}
# torch.save takes the object to save first, then the file path.
torch.save(state_dict, "yyy.pth")
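
If the filtered state dict still fails to match, check whether the keys carry a pipe.dit. prefix from the training wrapper; one option (a sketch) is to strip that prefix before saving:

# Drop the training-time "pipe.dit." prefix so the keys line up with
# the DiT parameter names the LoRA loader expects.
state_dict = {i.replace("pipe.dit.", ""): state_dict[i] for i in state_dict}
torch.save(state_dict, "yyy.pth")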

Artiprocher · Mar 27 '25

Thank you for your reply, but it doesn't work :( I have checked the state_dict keys; they all end with the LoRA weight suffixes, but it still cannot match any LoRA type.

qidai2000 · Mar 27 '25

And for full fine-tuning the problem is the same: the model type of the fine-tuned checkpoint fails to be detected and "No models are loaded" is reported.

qidai2000 · Mar 27 '25

Has anyone figured this out? I can't convert to a ComfyUI-friendly format; the LoRA weights aren't loaded properly.

aw93silverside · Apr 01 '25

A fix: change these two methods in models.lora.GeneralLoRAFromPeft:

def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
    state_dict_model = model.state_dict()
    device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
    lora_name_dict = self.get_name_dict(state_dict_lora)
    for name in lora_name_dict:
        weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
        weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
        if len(weight_up.shape) == 4:
            # Conv-style LoRA weights: collapse the trailing 1x1 dims, multiply, then restore them.
            weight_up = weight_up.squeeze(3).squeeze(2)
            weight_down = weight_down.squeeze(3).squeeze(2)
            weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            weight_lora = alpha * torch.mm(weight_up, weight_down)
        # Strip the "pipe.dit." prefix added during training so the key matches the DiT's own parameter name.
        weight_model = state_dict_model[name.replace("pipe.dit.", "")].to(device=computation_device, dtype=computation_dtype)
        weight_patched = weight_model + weight_lora
        state_dict_model[name.replace("pipe.dit.", "")] = weight_patched.to(device=device, dtype=dtype)
    print(f"    {len(lora_name_dict)} tensors are updated.")
    model.load_state_dict(state_dict_model)


def match(self, model: torch.nn.Module, state_dict_lora):
    lora_name_dict = self.get_name_dict(state_dict_lora)
    model_name_dict = {name: None for name, _ in model.named_parameters()}
    # Count LoRA target names (with the training-time prefix stripped) that exist in the model.
    matched_num = sum([i.replace("pipe.dit.", "") in model_name_dict for i in lora_name_dict])
    if matched_num == len(lora_name_dict):
        # Full match: return an (empty) lora_prefix / model_resource pair.
        return "", ""
    else:
        return None
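
The key change here is the name.replace("pipe.dit.", "") in both load and match: the trainer saves the LoRA keys with a pipe.dit. prefix, so stripping it lets them line up with the DiT's own parameter names. Equivalently, the prefix can be removed from the saved state dict itself before loading.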

njzxj · Apr 02 '25