TypeError: LlamaRotaryEmbedding.forward() got an unexpected keyword argument 'seq_len'
🐛 Describe the bug
File "/data/llmodel/miniconda3/envs/colossal/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/data/llmodel/miniconda3/envs/colossal/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "/data/llmodel/huap/ColossalAI/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py", line 133, in attention_forward cos, sin = self.rotary_emb(v, seq_len=kv_len) File "/data/llmodel/miniconda3/envs/colossal/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/data/llmodel/miniconda3/envs/colossal/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "/data/llmodel/miniconda3/envs/colossal/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) TypeError: LlamaRotaryEmbedding.forward() got an unexpected keyword argument 'seq_len'
Environment
python 3.10
transformers 4.39.2
colossalai 0.3.6
Hi 😃 This error occurs because, since transformers v4.39, the seq_len argument has been removed from LlamaRotaryEmbedding.forward(). The Colossal-LLaMA-2 code was written quite a bit earlier (around v4.34, I think). At that time, Flash Attention, which significantly speeds up attention and reduces memory consumption, had only just come out and had not yet been integrated into LlamaAttention, so we needed flash_attention_patch.py to enable it. That patch is therefore built against the function signature of an older Transformers version.
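As a rough illustration of the signature change, here is a minimal sketch assuming the transformers v4.39 API; the tensor shapes and the constructor call are for demonstration only and may differ in other releases:

```python
# Minimal sketch of the API change behind the error, assuming transformers >= 4.39.
# Shapes and constructor arguments are illustrative only.
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim, kv_len = 128, 16
rotary_emb = LlamaRotaryEmbedding(head_dim)
v = torch.randn(1, 8, kv_len, head_dim)           # (batch, heads, seq_len, head_dim)
position_ids = torch.arange(kv_len).unsqueeze(0)  # (batch, seq_len)

# Old call used by flash_attention_patch.py -- raises the TypeError above on v4.39:
# cos, sin = rotary_emb(v, seq_len=kv_len)

# Call style expected by transformers >= 4.38/4.39 -- position_ids is required:
cos, sin = rotary_emb(v, position_ids)
```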
By now, though, Flash Attention has already been integrated into the Hugging Face Llama implementation (see the LlamaFlashAttention2 and LlamaSdpaAttention classes). So I think you can just set use_flash_attn to False, and the Llama model will automatically use the flash attention feature. I believe this patch will be removed later on.
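For reference, here is a minimal sketch of loading Llama with the attention backend built into recent transformers releases. The checkpoint id is just a placeholder; "flash_attention_2" additionally requires the flash-attn package and a half-precision dtype, and omitting the argument falls back to the default backend (typically SDPA):

```python
# Sketch only: use the flash attention integrated into transformers instead of the patch.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",               # placeholder checkpoint
    torch_dtype=torch.bfloat16,                # flash_attention_2 needs fp16/bf16
    attn_implementation="flash_attention_2",   # or omit to use the default (SDPA)
)
# Quick check of which attention class transformers actually instantiated:
print(type(model.model.layers[0].self_attn).__name__)  # e.g. LlamaFlashAttention2
```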
When I change transformers to 4.38.0, it shows:
File "/home/user1/workspace/colossal-ai/ColossalAI/examples/language/llama2/attn.py", line 133, in attention_forward
cos, sin = self.rotary_emb(v, seq_len=kv_len)
File "/home/user1/anaconda3/envs/colossalai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
TypeError: LlamaRotaryEmbedding.forward() missing 1 required positional argument: 'position_ids'
So, which version of transformers should I use with flash attention?
Transformers v4.37 is OK. But just as I said, you can use v4.39 and still enjoy the speedup from flash_attn by setting use_flash_attn to False, because flash attention has been integrated into the transformers library without needing our patch.
Hi, it looks like if I set use_flash_attn to False, the GPU memory will increase.
and here is my env:
Package Version
------------------------- -----------
absl-py 2.1.0
aiohttp 3.9.3
aiosignal 1.3.1
annotated-types 0.6.0
async-timeout 4.0.3
attrs 23.2.0
bcrypt 4.1.2
beautifulsoup4 4.12.3
cachetools 5.3.3
certifi 2024.2.2
cffi 1.16.0
cfgv 3.4.0
charset-normalizer 3.3.2
click 8.1.7
cmake 3.29.0.1
colossalai 0.3.6
contexttimer 0.3.3
cryptography 42.0.5
datasets 2.18.0
decorator 5.1.1
Deprecated 1.2.14
dill 0.3.8
distlib 0.3.8
dropout-layer-norm 0.1
einops 0.7.0
fabric 3.2.2
filelock 3.13.3
flash-attn 2.2.1
frozenlist 1.4.1
fsspec 2024.2.0
fused-dense-lib 0.0.0
google 3.0.0
google-auth 2.29.0
google-auth-oauthlib 1.0.0
grpcio 1.62.1
huggingface-hub 0.22.2
identify 2.5.35
idna 3.6
invoke 2.2.0
Jinja2 3.1.3
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
lit 18.1.2
Markdown 3.6
markdown-it-py 3.0.0
MarkupSafe 2.1.5
mdurl 0.1.2
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.16
networkx 3.3
ninja 1.11.1.1
nodeenv 1.8.0
numpy 1.26.4
nvidia-cublas-cu11 11.10.3.66
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu11 11.7.101
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu11 11.7.99
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu11 8.5.0.96
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu11 10.9.0.58
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu11 10.2.10.91
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu11 11.4.0.1
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu11 11.7.4.91
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu11 2.14.3
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu11 11.7.91
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
packaging 24.0
pandas 2.2.1
paramiko 3.4.0
pip 23.3.1
platformdirs 4.2.0
pre-commit 3.7.0
protobuf 5.26.1
psutil 5.9.8
pyarrow 15.0.2
pyarrow-hotfix 0.6
pyasn1 0.6.0
pyasn1_modules 0.4.0
pycparser 2.22
pydantic 2.6.4
pydantic_core 2.16.3
Pygments 2.17.2
PyNaCl 1.5.0
python-dateutil 2.9.0.post0
pytz 2024.1
PyYAML 6.0.1
ray 2.10.0
referencing 0.34.0
regex 2023.12.25
requests 2.31.0
requests-oauthlib 2.0.0
rich 13.7.1
rotary-emb 0.1
rpds-py 0.18.0
rsa 4.9
safetensors 0.4.2
sentencepiece 0.1.99
setuptools 68.2.2
six 1.16.0
soupsieve 2.5
sympy 1.12
tensorboard 2.14.0
tensorboard-data-server 0.7.2
tokenizers 0.13.3
torch 2.0.0
tqdm 4.66.2
transformers 4.33.3
triton 2.0.0
typing_extensions 4.11.0
tzdata 2024.1
urllib3 2.2.1
virtualenv 20.25.1
Werkzeug 3.0.2
wheel 0.41.2
wrapt 1.16.0
xentropy-cuda-lib 0.1
xxhash 3.4.1
yarl 1.9.4