Efficient Model Loading
Hello, I'm trying to LoRA-finetune a 32B model using the Lightning interface described in python-api.md, and I'm running into issues with loading the model checkpoint.
FSDP:
```python
def configure_model(self):
    if self.model is not None:
        return
    self.model = GPT.from_name(
        name=self.model_name,
        lora_r=self.lora_r,
        lora_alpha=self.lora_alpha,
        lora_dropout=self.lora_dropout,
        lora_key=self.lora_key,
        lora_value=self.lora_value,
        lora_query=self.lora_query,
    )
    self.load_checkpoint()
    self.configure_head()
    make_only_lora_head_as_trainable(self)
```
and
```python
def load_checkpoint(self):
    if self.checkpoint_path is None:
        checkpoint_dir = Path("/opt/dlami/nvme") / "models"
        self.checkpoint_path = checkpoint_dir / self.model_name / "lit_model.pth"
        self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        if not self.checkpoint_path.exists():
            download_from_hub(repo_id=self.model_name, checkpoint_dir=checkpoint_dir)
    state_dict = torch.load(self.checkpoint_path, weights_only=True)
    self.model.load_state_dict(state_dict, strict=False, assign=False)
```
When I use this on 8 A100s, I run out of CPU RAM (1 TB). When I use fewer GPUs, initializing the model is very slow (it doesn't do the fast lazy init for some reason), but it moves forward and then dies with an OOM during the forward pass.
So, I wrapped the model creation with

```python
self.trainer.init_context(empty_init=True)
```
and then it gets further, but crashes with an error saying FSDP serialization needs all parameters to be in the same data type.
DeepSpeed
I tried switching to DeepSpeed, but model loading fails because the model is sharded before the weights are loaded.
Question
Can you provide some hints on how to fix these initialization issues? Can you please suggest a way to initialize very large models with pretrained weights for training using Lightning?
Thank you! Prabhu
Hi @prabhuteja12, can you share the model details and the error stack traces to help debug the problem?
Hi @raishish
I see this with Qwen2.5 Coder 32B, though I don't think it is model-specific.
Regarding the stack trace: there really isn't much more than what I wrote above. Could you point me to documentation on how to initialize/fine-tune large (>10B) litgpt models? The examples mostly cover smaller models, where most modern GPUs don't run into memory issues.