
Efficient Model Loading

Open prabhuteja12 opened this issue 8 months ago • 2 comments

Hello, I'm trying to LoRA-finetune a 32B model using the Lightning interface described in python-api.md, and I'm facing some issues with loading the model checkpoint.

FSDP

def configure_model(self):
    if self.model is not None:
        return
    self.model = GPT.from_name(
        name=self.model_name,
        lora_r=self.lora_r,
        lora_alpha=self.lora_alpha,
        lora_dropout=self.lora_dropout,
        lora_key=self.lora_key,
        lora_value=self.lora_value,
        lora_query=self.lora_query,
    )
    self.load_checkpoint()
    self.configure_head()
    make_only_lora_head_as_trainable(self)

and

def load_checkpoint(self):
    if self.checkpoint_path is None:
        checkpoint_dir = Path("/opt/dlami/nvme") / "models"
        self.checkpoint_path = checkpoint_dir / self.model_name / "lit_model.pth"
        self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        if not self.checkpoint_path.exists():
            download_from_hub(repo_id=self.model_name, checkpoint_dir=checkpoint_dir)
    state_dict = torch.load(self.checkpoint_path, weights_only=True)
    self.model.load_state_dict(state_dict, strict=False, assign=False)

When I run this on 8 A100s, I run out of CPU RAM (1 TB). With fewer GPUs, model initialization is very slow (the fast lazy init doesn't kick in for some reason), but it does proceed, and then dies with an OOM in the forward pass.

So I wrapped the model creation in self.trainer.init_context(empty_init=True); initialization then moves forward, but training crashes with an error saying FSDP requires all parameters to be in the same data type.

DeepSpeed

I tried switching to DeepSpeed, but model loading fails because the model is sharded before the weights are loaded.

Question

Can you provide some hints on how to fix these initialization issues? Specifically, can you provide a way to initialize very large models with pretrained weights for training with Lightning?

Thank you! Prabhu

prabhuteja12 avatar Aug 05 '25 21:08 prabhuteja12

Hi @prabhuteja12, can you share the model details and error stack traces to help debug the problem?

raishish avatar Aug 07 '25 16:08 raishish

Hi @raishish

I see this with Qwen2.5 Coder 32B, though I don't think the issue is model-specific.

Regarding the stack trace: there really isn't much more than what I wrote above. Can I ask for a link to documentation on how to initialize/fine-tune large (>10B) litgpt models? The examples mostly cover smaller models, where most modern GPUs don't run into memory issues.

prabhuteja12 avatar Aug 08 '25 08:08 prabhuteja12