
Adding LoRA to BERT finetuning example

ankur6ue opened this issue 1 year ago

Hello, I took the code for finetuning the BERT model on the SQuAD dataset and refactored it as follows:

  • Separated feature generation into its own file and added multi-processing
  • Added support for distributed data parallel during finetuning, using PyTorch DDP. Removed all references to apex.amp, which is now deprecated. My code is here: https://github.com/ankur6ue/distributed_training/tree/master
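For the first bullet, the multi-process feature generation pattern can be sketched with the standard library alone. The `featurize` function here is a hypothetical stand-in for the actual SQuAD feature builder (tokenization, span alignment, etc.), not the code in the linked repo:

```python
from multiprocessing import Pool

def featurize(example):
    """Hypothetical stand-in for SQuAD feature extraction;
    real code would tokenize and align answer spans here."""
    return {
        "qid": example["qid"],
        "num_context_chars": len(example["context"]),
    }

def build_features(examples, workers=2):
    # Fan the CPU-bound feature generation out over a process pool.
    with Pool(processes=workers) as pool:
        return pool.map(featurize, examples)

if __name__ == "__main__":
    examples = [{"qid": i, "context": "x" * (10 * (i + 1))} for i in range(3)]
    features = build_features(examples)
    print([f["num_context_chars"] for f in features])  # [10, 20, 30]
```

`Pool.map` preserves input order, so features line up with their source examples even though the workers run concurrently.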

I'm now trying to add LoRA to the code. I did the following:

    model = modeling.BertForQuestionAnswering(config)

    peft_config = LoraConfig(
        task_type=TaskType.QUESTION_ANS,
        inference_mode=False,
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="all", # will train the bias matrices.
        target_modules=["key", "query", "value"],
        modules_to_save=["qa_outputs"],
    )
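For context on what `r`, `lora_alpha`, and `target_modules` control: LoRA freezes each targeted weight W and learns a low-rank update BA, scaled by `lora_alpha / r`, with B initialized to zero so training starts exactly from the pretrained behavior. A dependency-free sketch of that forward pass (an illustration of the technique, not the peft internals):

```python
def linear(x, W):
    """y_i = sum_j W[i][j] * x[j] -- a plain dense layer, no bias."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]

def lora_linear(x, W, A, B, lora_alpha=16, r=16):
    """LoRA forward: frozen W plus a rank-r update B @ A,
    scaled by lora_alpha / r (names mirror the LoraConfig above)."""
    scaling = lora_alpha / r
    h = linear(x, A)       # down-project input to rank r
    delta = linear(h, B)   # up-project back to the output dim
    return [y + scaling * d for y, d in zip(linear(x, W), delta)]

# With B initialized to zeros, the LoRA branch contributes nothing:
x, W = [1.0, 2.0], [[1.0, 0.0], [0.0, 1.0]]
print(lora_linear(x, W, A=[[1.0, 1.0]], B=[[0.0], [0.0]],
                  lora_alpha=2, r=1))  # [1.0, 2.0]
```

Only A, B (and here, because of `bias="all"`, the biases) receive gradients; the pretrained W stays frozen, which is why `print_trainable_parameters()` reports a small fraction of the model.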

    model.load_state_dict(checkpoint, strict=False)

    lora_model = get_peft_model(model, peft_config)
    lora_model.print_trainable_parameters()
    lora_model.to(device)

    optimizer = AdamW(
        lora_model.parameters(),
        lr=args.learning_rate,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
        eps=1e-8,               # args.adam_epsilon - default is 1e-8
    )

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,  # default value in run_glue.py
        num_training_steps=num_train_optimization_steps,
    )
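A small sketch of the learning-rate multiplier this scheduler applies, mirroring the documented behavior of transformers' `get_linear_schedule_with_warmup`: a linear ramp from 0 to 1 over the warmup steps, then linear decay to 0 at the end of training:

```python
def lr_lambda(step, num_warmup_steps, num_training_steps):
    """Multiplier applied to the base learning rate at a given step."""
    if step < num_warmup_steps:
        # Linear warmup: 0 -> 1 over the first num_warmup_steps steps.
        return step / max(1, num_warmup_steps)
    # Linear decay: 1 -> 0 over the remaining steps.
    return max(
        0.0,
        (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps),
    )

# e.g. 10 warmup steps out of 100 total:
print(lr_lambda(0, 10, 100))    # 0.0
print(lr_lambda(10, 10, 100))   # 1.0
print(lr_lambda(100, 10, 100))  # 0.0
```

The warmup phase matters for AdamW finetuning because the optimizer's second-moment estimates are noisy at the start; ramping the learning rate up avoids large early updates.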

In the training loop:

    start_logits, end_logits = lora_model(input_ids, segment_ids, input_mask)

The line above raises an exception:

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    AttributeError: 'BertConfig' object has no attribute 'use_return_dict'
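The traceback suggests the mismatch is between peft's `QUESTION_ANS` task wrapper, whose forward reads `self.config.use_return_dict`, and the custom `BertConfig` in DeepLearningExamples, which (unlike Hugging Face's `PretrainedConfig`) never defines that attribute. A dependency-free reproduction of that failure mode, with one possible shim; the stub classes and the shim are assumptions for illustration, not a verified fix for the full pipeline:

```python
class CustomBertConfig:
    """Stands in for the DeepLearningExamples BertConfig, which
    does not define use_return_dict."""
    pass

class PeftQAWrapper:
    """Stands in for a peft task wrapper that assumes an HF-style config."""
    def __init__(self, config):
        self.config = config

    def forward(self, return_dict=None):
        # This is the line that fails when the attribute is missing.
        return return_dict if return_dict is not None else self.config.use_return_dict

cfg = CustomBertConfig()
wrapper = PeftQAWrapper(cfg)
try:
    wrapper.forward()
except AttributeError as e:
    print("reproduced:", e)

# Possible shim: define the missing attribute before wrapping the model,
# so the wrapper falls back to tuple outputs.
cfg.use_return_dict = False
print(wrapper.forward())  # False
```

This only removes the immediate AttributeError; the custom model's forward signature and tuple outputs may still differ from what the task wrapper expects downstream.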

If I do the following instead:

    start_logits, end_logits = model(input_ids, segment_ids, input_mask)

Then the forward pass works. What's the right way to use peft on this BERT model?

ankur6ue · Jun 07 '24 21:06