
T5 with PrefixTuning: error

Open · yuyijiong opened this issue 2 years ago · 2 comments

from transformers import AutoModelForSeq2SeqLM, T5Tokenizer
from peft import get_peft_model, TaskType, PrefixTuningConfig

# Load the base seq2seq model and its tokenizer
model_name_or_path = "t5-small"
tokenizer_name_or_path = "t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
tokenizer = T5Tokenizer.from_pretrained(tokenizer_name_or_path)

# Prefix tuning: prepend 20 trainable virtual tokens to every attention layer
peft_config = PrefixTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    num_virtual_tokens=20,
)

model = get_peft_model(model, peft_config)
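# Sanity check (a hedged aside, not part of the original report): with
# prefix tuning only the prompt encoder should be trainable, a small
# fraction of t5-small's parameters.
model.print_trainable_parameters()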

model.cuda()

# Tokenize a toy example; the labels are shifted by hand to build decoder
# inputs. (Note: T5 can also build decoder_input_ids from labels itself by
# prepending the pad token, so this manual shift is just the reporter's setup.)
input_ids = tokenizer.encode("Is dog an animal?", return_tensors="pt").to(model.device)
labels = tokenizer.encode("yes", return_tensors="pt").to(model.device)
decoder_input_ids = labels[:, :-1].contiguous().to(model.device)
labels = labels[:, 1:].clone()

outputs = model(input_ids=input_ids, labels=labels, decoder_input_ids=decoder_input_ids)
Traceback (most recent call last):
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-d8dc239897c0>", line 24, in <module>
    outputs = model(input_ids=input_ids, labels=labels, decoder_input_ids=decoder_input_ids)
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/peft/peft_model.py", line 676, in forward
    return self.base_model(
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1648, in forward
    decoder_outputs = self.decoder(
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1040, in forward
    layer_outputs = layer_module(
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 699, in forward
    cross_attention_outputs = self.layer[1](
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 613, in forward
    attention_output = self.EncDecAttention(
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/datamining/miniconda3/envs/lxl/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 538, in forward
    scores += position_bias_masked
RuntimeError: The size of tensor a (20) must match the size of tensor b (7) at non-singleton dimension 3

T5 with PrefixTuning raises this error. It is possibly caused by the shape of past_key_values.
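One way to check that hypothesis is to inspect what PEFT prepends to the attention layers. A minimal sketch, assuming get_prompt (a PeftModel helper whose return layout can differ across PEFT versions):

# Hedged diagnostic: get_prompt builds the virtual-token cache that prefix
# tuning passes as past_key_values. Assumption: somewhere in this structure
# the key/value sequence length is num_virtual_tokens = 20, matching
# "tensor a (20)" in the traceback, while the encoder's real sequence
# length is 7 ("tensor b (7)").
past_key_values = model.get_prompt(batch_size=1)
print(type(past_key_values))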

yuyijiong · Feb 12 '23 13:02

Hello @yuyijiong, I'm able to run this with t5-small without any issues: https://github.com/pacman100/temp/blob/main/peft_prefix_tuning_seq2seq%20(1).ipynb

pacman100 · Feb 12 '23 14:02

The code above from the issue description also works fine for me. Could you upgrade to the latest versions of Transformers, Accelerate, and PEFT and mention which versions you're using:

[Screenshot, Feb 12 '23: installed versions of transformers, accelerate, and peft from the working run]
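For reference, a minimal way to print the installed versions (standard __version__ attributes):

import accelerate
import peft
import transformers

# Print the versions of the three relevant libraries
print("transformers:", transformers.__version__)
print("accelerate:", accelerate.__version__)
print("peft:", peft.__version__)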

pacman100 · Feb 12 '23 14:02

I upgraded transformers from 4.25.1 to 4.26.1, and the error disappeared. That means PEFT needs transformers >= 4.26.1. Thank you.
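A minimal guard against the older release, as a sketch (the >= 4.26.1 bound comes from this thread; the official pin lives in PEFT's setup.py):

# Hedged version guard: fail fast on the transformers release that triggers
# the position_bias shape mismatch with prefix tuning. packaging ships as a
# dependency of transformers.
from packaging import version
import transformers

assert version.parse(transformers.__version__) >= version.parse("4.26.1"), (
    "Prefix tuning with T5 appears to need transformers >= 4.26.1 (see this issue)."
)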

yuyijiong · Feb 13 '23 09:02

Great! I'll update the setup.py accordingly, thank you! Feel free to close the issue.

pacman100 · Feb 13 '23 09:02