Question about saving peft checkpoint
🐛 Describe the bug
From my understanding, when saving checkpoints for peft models (see here), trlx removes pytorch_model.bin before calling save_pretrained, which in my opinion makes the removal useless.
Is this intentional, or should we move the removal code to after save_pretrained is called?
Here is an example of a directory resulting from save_pretrained:
adapter_config.json adapter_model.bin optimizer.bin pytorch_model.bin random_states_0.pkl special_tokens_map.json spiece.model tokenizer_config.json tokenizer.json
Which trlX version are you using?
0.7.0
Additional system and package information
3.10.12
Hello @nhanph!
No, the removal is not useless. If you check the contents of pytorch_model.bin immediately after this line: https://github.com/CarperAI/trlx/blob/bcd237f1e94c84c5c9f5a4086bab34c0946e3fa7/trlx/trainer/accelerate_base_trainer.py#L309 you will see that it contains the whole model state dictionary, which is not needed. After deleting it and recreating it with model.save_pretrained with heads_only=True, only the value heads are kept there.
>>> list(before_deletion_state_dict.keys())[:32]
['v_head.0.weight', 'v_head.0.bias', 'v_head.2.weight', 'v_head.2.bias', 'base_model.model.transformer.wte.weight', 'base_model.model.transformer.wpe.weight', 'base_model.model.transformer.h.0.ln_1.weight', 'base_model.model.transformer.h.0.ln_1.bias', 'base_model.model.transformer.h.0.attn.c_attn.weight', 'base_model.model.transformer.h.0.attn.c_attn.bias', 'base_model.model.transformer.h.0.attn.c_attn.lora_A.default.weight', 'base_model.model.transformer.h.0.attn.c_attn.lora_B.default.weight', 'base_model.model.transformer.h.0.attn.c_proj.weight', 'base_model.model.transformer.h.0.attn.c_proj.bias', 'base_model.model.transformer.h.0.ln_2.weight', 'base_model.model.transformer.h.0.ln_2.bias', 'base_model.model.transformer.h.0.mlp.c_fc.weight', 'base_model.model.transformer.h.0.mlp.c_fc.bias', 'base_model.model.transformer.h.0.mlp.c_proj.weight', 'base_model.model.transformer.h.0.mlp.c_proj.bias', 'base_model.model.transformer.h.1.ln_1.weight', 'base_model.model.transformer.h.1.ln_1.bias', 'base_model.model.transformer.h.1.attn.c_attn.weight', 'base_model.model.transformer.h.1.attn.c_attn.bias', 'base_model.model.transformer.h.1.attn.c_attn.lora_A.default.weight', 'base_model.model.transformer.h.1.attn.c_attn.lora_B.default.weight', 'base_model.model.transformer.h.1.attn.c_proj.weight', 'base_model.model.transformer.h.1.attn.c_proj.bias', 'base_model.model.transformer.h.1.ln_2.weight', 'base_model.model.transformer.h.1.ln_2.bias', 'base_model.model.transformer.h.1.mlp.c_fc.weight', 'base_model.model.transformer.h.1.mlp.c_fc.bias']
>>> list(after_save_pretrained_state_dict.keys())[:32]
['v_head.0.weight', 'v_head.0.bias', 'v_head.2.weight', 'v_head.2.bias']
Thank you @maxreciprocate, I get the point about saving the model's value head now.
My original question came from running the ILQL training script: I see a pytorch_model.bin whose size is comparable to the original model, so I suspect the base model is also being saved. Is heads_only set to true somewhere during checkpointing when using peft_config? I cannot find it set anywhere.