direct-preference-optimization

RuntimeError: Error(s) in loading state_dict on a custom model

Open gopstrit opened this issue 2 years ago • 11 comments

Hi,

We're trying to use DPO to align our fine-tuned Falcon-7B model. Since we already have a fine-tuned Falcon-7B model, I ran SFT on a small sample (~1,000 observations) to generate the policy checkpoint required for DPO.

I've created the config/model file and also specified module = ['query_key_value'] for LoRA.

When running DPO, I get the error (RuntimeError: Error(s) in loading state_dict) at the following line:

[screenshot of the failing line]

This is the error that I get: [screenshots of the error]

And here is my model architecture: [screenshot]

I went ahead and added strict=False on that line, i.e. reference_model.load_state_dict(state_dict['state'], strict=False), which sort of solves (or rather hides?) the error.

After DPO finished, I loaded the state_dict into the model and ran into a similar error:

[screenshot of the error]

Again, if I do non_aligned_model.load_state_dict(torch.load(model_archive_name), strict=False), it avoids the error but still reports a long list of incompatible keys (though I can go ahead and run inference):

[screenshot of the incompatible keys]
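With strict=False, PyTorch's load_state_dict does not raise; it returns a named tuple whose `missing_keys` and `unexpected_keys` fields list every parameter it could not match, and any missing parameter simply keeps the value it had before loading. A long list there usually means the checkpoint was not applied at all. A minimal sketch of a guard (the helper name `check_incompatible_keys` is hypothetical, and the key names in the demo are illustrative) that turns those lists back into a hard error:

```python
def check_incompatible_keys(missing_keys, unexpected_keys, max_report=3):
    """Raise ValueError if a strict=False load skipped anything.

    Hypothetical usage with PyTorch:
        result = model.load_state_dict(state_dict, strict=False)
        check_incompatible_keys(result.missing_keys, result.unexpected_keys)
    """
    problems = []
    if missing_keys:
        # Parameters the model has but the checkpoint lacks; they keep their old values.
        problems.append(f"{len(missing_keys)} missing (e.g. {list(missing_keys)[:max_report]})")
    if unexpected_keys:
        # Checkpoint entries with no matching parameter; they are silently ignored.
        problems.append(f"{len(unexpected_keys)} unexpected (e.g. {list(unexpected_keys)[:max_report]})")
    if problems:
        raise ValueError("checkpoint does not match model: " + "; ".join(problems))


# Illustrative key names only: one parameter saved under a wrapper prefix.
try:
    check_incompatible_keys(
        missing_keys=["transformer.h.0.self_attention.query_key_value.weight"],
        unexpected_keys=["base_model.model.transformer.h.0.self_attention.query_key_value.weight"],
    )
except ValueError as err:
    print(err)
```

Failing loudly here is a design choice: it is cheaper to stop before inference than to evaluate a model that silently kept its pre-DPO weights.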

For this, I ran DPO with the basic trainer as follows:

```shell
ulimit -n 64000; python -u train_qlora.py model=falcon7b datasets=[hh] loss=dpo loss.beta=0.1 \
    model.archive=/ipython-docker/notebook_dir/data/DPO/temp_cache_gopal/dpo/custom_falcon_sft_2023-08-09_13-49-14_625672/LATEST/policy.pt \
    exp_name=custom_falcon_dpo gradient_accumulation_steps=4 batch_size=4 eval_batch_size=8
```

**Question about the above output**

  • What do you think the issue is here, and how can I solve it?

**Another question about the FSDP trainer**

  • I also tried the FSDP trainer, which didn't work for me (or perhaps I didn't set it up correctly). I used the following launch command:

```shell
ulimit -n 64000; python -u train_qlora.py model=falcon7b datasets=[hh] loss=dpo loss.beta=0.1 \
    model.archive=/ipython-docker/notebook_dir/data/DPO/temp_cache_gopal/dpo/army_falcon_sft_2023-08-09_13-49-14_625672/LATEST/policy.pt \
    exp_name=army_falcon_dpo gradient_accumulation_steps=2 batch_size=2 eval_batch_size=2 \
    model.fsdp_policy_mp=bfloat16 trainer=FSDPTrainer sample_during_eval=false
```

I also set block_name = DecoderLayer.

This is the error that I got: [screenshot]

**@eric-mitchell, I would appreciate your guidance on this. Thank you! (By the way, very cool implementation of DPO!)**

gopstrit • Aug 10 '23 04:08

To start, it looks like your checkpoint includes a base_model.model. prefix in front of each parameter name, so PyTorch can't find the parameters it expects. I assume this is because you saved the checkpoint from a model wrapped for, e.g., LoRA, but the reference model you've built doesn't have the LoRA wrapper applied. You should instantiate the reference model the same way you instantiate the policy (since at the first training step, the policy and the reference model should be identical).
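To make the mismatch concrete, the sketch below compares the key names a model expects with those stored in a checkpoint, the same way load_state_dict(strict=False) reports them, and then strips the `base_model.model.` prefix. The key names are illustrative assumptions: real names depend on the Falcon modelling code and the peft version (newer peft releases, for example, also insert a `base_layer` component). Note that removing the prefix still leaves the LoRA-specific entries (`lora_A`/`lora_B`) with nowhere to go in an unwrapped model, which is why wrapping the reference model the same way as the policy is the cleaner fix.

```python
PEFT_PREFIX = "base_model.model."


def diff_keys(model_keys, ckpt_keys):
    """Return (missing, unexpected) key lists, as load_state_dict(strict=False) would."""
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    return sorted(model_keys - ckpt_keys), sorted(ckpt_keys - model_keys)


def strip_prefix(keys, prefix=PEFT_PREFIX):
    """Drop a wrapper prefix from every key that carries it."""
    return [k[len(prefix):] if k.startswith(prefix) else k for k in keys]


# Hypothetical key names for one Falcon decoder layer.
model_keys = [
    "transformer.h.0.self_attention.query_key_value.weight",
    "transformer.h.0.mlp.dense_h_to_4h.weight",
]
ckpt_keys = [
    PEFT_PREFIX + "transformer.h.0.self_attention.query_key_value.weight",
    PEFT_PREFIX + "transformer.h.0.self_attention.query_key_value.lora_A.default.weight",
    PEFT_PREFIX + "transformer.h.0.self_attention.query_key_value.lora_B.default.weight",
    PEFT_PREFIX + "transformer.h.0.mlp.dense_h_to_4h.weight",
]

missing, unexpected = diff_keys(model_keys, ckpt_keys)
print(len(missing), len(unexpected))  # 2 4: nothing matches while the prefix is present

missing, unexpected = diff_keys(model_keys, strip_prefix(ckpt_keys))
print(missing)     # []
print(unexpected)  # only the lora_A / lora_B entries remain
```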

eric-mitchell • Aug 10 '23 07:08

Hi @eric-mitchell, thanks a lot for your reply.

Two quick questions:

1) Do you think my approach of running SFT on ~1,000 samples of an already fine-tuned Falcon model is appropriate? I did it because DPO requires a checkpoint, and my fine-tuned Falcon model (which was fine-tuned separately) only has .bin files.

2) About the specific problem above: I simply used the existing code in train_qlora, and it saved the checkpoint automatically. All I did was pass the latest folder_path/policy.pt when running DPO. How do you suggest I solve the issue? I didn't quite understand what you meant by instantiating the reference model the same way as the policy.

[screenshot]

gopstrit • Aug 10 '23 07:08

Ah, are you suggesting I need to do this for the reference model as well?

[screenshot]

gopstrit • Aug 10 '23 07:08

Okay, I added this, and it seems to have worked! (Will update if I get through the whole process).

[screenshot]

Another quick question: in one of the posts, I saw that you suggest loading the model like this after DPO concludes:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
model_archive_name = "/ipython-docker/notebook_dir/data/DPO/temp_cache_gopal/dpo/custom_falcon_dpo_2023-08-09_17-52-40_092080/LATEST/policy.pt"

model.load_state_dict(torch.load(model_archive_name))
```

How would you suggest I save the final complete model?

Thanks a lot @eric-mitchell

gopstrit • Aug 10 '23 07:08

@eric-mitchell I finished DPO, but when merging policy.pt into the model, I got the following error: [screenshots of the error]

Am I doing something wrong here?

Another issue:

  • I have around 80k samples in my training dataset, but DPO finished after only ~33,000. Why is that? [screenshot]

gopstrit • Aug 10 '23 20:08

Sounds like you worked the lora part out! For loading the new checkpoint, the issue is that you need to load torch.load(model_archive_name)['state'], since the archived parameters are in the 'state' sub-dictionary, not the top-level dictionary.
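Following the reply above, a small sketch of a loader step that unwraps the 'state' sub-dictionary when it is present and otherwise passes a flat state_dict through unchanged. The helper name `extract_state_dict` is hypothetical, and the `'step_idx'` key in the example is only an illustrative stand-in for whatever metadata sits next to the weights in the archive.

```python
def extract_state_dict(checkpoint):
    """Return the parameter dict from a loaded checkpoint.

    Hypothetical usage with PyTorch:
        ckpt = torch.load(model_archive_name, map_location="cpu")
        model.load_state_dict(extract_state_dict(ckpt))
    """
    if isinstance(checkpoint, dict) and "state" in checkpoint:
        # Trainer archive: the parameters are nested under 'state'.
        return checkpoint["state"]
    # Already a flat state_dict: return it unchanged.
    return checkpoint


# Toy archive with placeholder values instead of tensors.
archive = {"step_idx": 1000, "state": {"lm_head.weight": [0.1, 0.2]}}
print(sorted(extract_state_dict(archive)))  # ['lm_head.weight']
```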

eric-mitchell • Aug 11 '23 00:08

Thanks @eric-mitchell for your great support. :) I'll do that and post an update here.

About the second question: why do you think the process stops at ~32,000 samples? (I noticed the same thing yesterday when running on ~63,000 samples; this run uses ~80,000.)

gopstrit • Aug 11 '23 00:08

Are you passing an argument for n_examples or n_epochs? Can you check in preference_datasets.py when you create the dataset, are you sure there are 80k preference pairs? Not sure why training would stop early if there are actually 80k pairs in the data iterator.

eric-mitchell • Aug 11 '23 00:08

I'm using the same code as get_hh(), since my data has the same structure.

When it initially downloads the data from Hugging Face, it shows the same number of training and test examples that I have. But training still ends at ~32,000.

The number of epochs is set to 1.

gopstrit • Aug 11 '23 01:08

Also, do you suggest wrapping the model with get_peft_model before loading the state? (I gave it a quick try without it, which gave me an error.) When I applied get_peft_model(pre-trained model, loraconfig) first, it worked:

[screenshot]

But the saved model is much smaller than the original. I'm assuming I need to merge the resulting adapters back into the base model.

Here's the model architecture after that step: [screenshot]

[screenshots]
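The smaller file is expected: calling save_pretrained on a peft-wrapped model writes only the adapter weights, not the full base model. Merging folds each low-rank update into its frozen weight as W' = W + (alpha / r) · B A, where r is the LoRA rank and alpha the scaling hyperparameter from the LoRA config; peft exposes this as `merge_and_unload()`, after which save_pretrained on the returned model writes full-size weights. The pure-Python sketch below only illustrates that arithmetic on tiny hand-picked matrices (`matmul` and `merge_lora` are hypothetical helpers); it is not how peft implements the merge.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def merge_lora(w, lora_a, lora_b, alpha, r):
    """Return W + (alpha / r) * B @ A for a frozen weight W of shape d x k.

    lora_b has shape d x r and lora_a has shape r x k.
    """
    scale = alpha / r
    delta = matmul(lora_b, lora_a)  # d x k low-rank update
    return [[w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]


# Toy example: d = k = 2, rank r = 1, alpha = 2 (so the scale is 2.0).
W = [[1.0, 0.0],
     [0.0, 1.0]]
A = [[1.0, 2.0]]   # r x k
B = [[0.5],
     [1.0]]        # d x r
print(merge_lora(W, A, B, alpha=2, r=1))  # [[2.0, 2.0], [2.0, 5.0]]
```

Because the merged matrix has the same shape as W, the merged model has the same parameter names and size as the original base model, with no adapter modules left.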

gopstrit • Aug 11 '23 01:08

@eric-mitchell

I've figured out where the number of training examples gets reduced.

It happens in the following section of preference_datasets.py: [screenshot]

You can see that I have 349 examples in the training dataset, but the count drops to 139 afterwards: [screenshot]

Does this look concerning?
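One plausible explanation, offered as an assumption rather than something confirmed in this thread: if the dataset is built as a dictionary keyed by prompt, with each prompt accumulating its own list of preference pairs, then rows that share a prompt collapse into a single entry, and the length of the dataset counts unique prompts rather than pairs. The sketch below groups hypothetical rows that way (`group_by_prompt` is a hypothetical helper) and counts both quantities. If the two numbers differ for your data, a drop like 349 to 139 may simply reflect repeated prompts; whether any pairs are actually skipped then depends on whether the batch iterator walks every pair of each entry.

```python
from collections import defaultdict


def group_by_prompt(rows):
    """Group (prompt, chosen, rejected) rows into {prompt: {'responses': [...], 'pairs': [...]}}.

    Each pair stores indices into that prompt's response list, chosen first.
    """
    data = defaultdict(lambda: {"responses": [], "pairs": []})
    for prompt, chosen, rejected in rows:
        entry = data[prompt]
        n = len(entry["responses"])
        entry["responses"].extend([chosen, rejected])
        entry["pairs"].append((n, n + 1))
    return dict(data)


# Hypothetical rows: the first two share the same prompt.
rows = [
    ("Q1", "good answer", "bad answer"),
    ("Q1", "another good answer", "another bad answer"),
    ("Q2", "good answer", "bad answer"),
]
data = group_by_prompt(rows)
n_prompts = len(data)
n_pairs = sum(len(entry["pairs"]) for entry in data.values())
print(n_prompts, n_pairs)  # 2 3
```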

gopstrit • Aug 13 '23 05:08