Missing dependency on jax and jaxlib in gptj_ckpt_convert.py
Branch/Tag/Commit
main
Docker Image Version
nvcr.io/nvidia/pytorch:22.09-py3
GPU name
AGX Orin
CUDA Driver
520.61.03
Reproduced Steps
Build FasterTransformer docker container.
In the container, with GPT-J weights accessible via /ft_workspace:
python3 /workspace/FasterTransformer/examples/pytorch/gptj/utils/gptj_ckpt_convert.py --output-dir /ft_workspace/j6b_ckpt --ckpt-dir /ft_workspace/step_383500/ --n-inference-gpus 1
This yields a Python error stating that jax and jaxlib are not found.
They can be built using the instructions found here: https://jax.readthedocs.io/en/latest/developer.html#building-from-source
These build steps should be added to the Dockerfile, but even once jax and jaxlib are built and installed, running gptj_ckpt_convert.py with the command line above throws the following error:
loading
loading shards for part 0
read from checkpoint
Traceback (most recent call last):
File "/workspace/FasterTransformer/examples/pytorch/gptj/utils/gptj_ckpt_convert.py", line 277, in <module>
checkpoint = main(in_path, num_layers)
File "/workspace/FasterTransformer/examples/pytorch/gptj/utils/gptj_ckpt_convert.py", line 235, in main
old_shape = (params.shape[1],)
IndexError: tuple index out of range
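For context, this particular IndexError is what Python raises when `shape[1]` is requested on a shape tuple with fewer than two entries, i.e. the loaded parameter is 0-D or 1-D at that point. The snippet below is only a minimal sketch of that failure mode. The shapes and the `second_dim` helper are hypothetical and do not claim to mirror what gptj_ckpt_convert.py actually loads at line 235.

```python
def second_dim(shape):
    """Mimic `(params.shape[1],)`, but return None when there is no second axis."""
    if len(shape) < 2:
        return None
    return (shape[1],)


shape_1d = (4096,)      # hypothetical shape of a 1-D parameter
shape_2d = (1, 4096)    # hypothetical shape of a 2-D parameter

# Indexing the second axis of a 1-D shape reproduces the error in the traceback.
try:
    old_shape = (shape_1d[1],)
except IndexError as err:
    message = str(err)
    print("IndexError:", message)

print(second_dim(shape_2d))
print(second_dim(shape_1d))
```

Printing `params.shape` just before the failing line would show whether the checkpoint entry has the number of dimensions the script expects.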
Did you install the required packages mentioned in gptj_guide.md?
I don't think I missed any, but I'll repeat it from scratch in the morning to be certain. EDIT: A quick review of gptj_guide.md shows that it includes the instruction to install jax and jaxlib via pip: "pip3 install fire jax jaxlib". However, the jax and jaxlib documentation mentioned earlier clearly establishes that the versions you can get via pip wheels are NOT CUDA enabled. It is clear that the only way to get the CUDA-enabled versions is to build from source, which is why I suggest they be built in the FasterTransformer Dockerfile/image. That said, I'm happy to try the non-CUDA-enabled versions, but it seems odd that this would be the recommended/supported approach.
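A quick way to confirm which of the packages named in gptj_guide.md are actually visible inside the container is to ask Python directly. This is a minimal sketch using only the standard library. The helper names `missing_modules` and `installed_version` are made up for illustration, and the list of required names is taken from the pip3 command quoted above.

```python
import importlib.util
from importlib import metadata


def missing_modules(names):
    """Return the module names that cannot be found on the current import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def installed_version(dist_name):
    """Return the installed distribution version, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


required = ["fire", "jax", "jaxlib"]  # names from `pip3 install fire jax jaxlib`
print("missing:", missing_modules(required))
for name in required:
    print(name, installed_version(name))
```

If jax does import, calling `jax.default_backend()` should report whether it is running on cpu or gpu, which helps tell a CPU-only build from a CUDA-enabled one.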
As it turns out, there is no jaxlib (or libjax, as the jax maintainers refer to it) wheel distribution for ARM, so it must be built from source. The one thing I haven't tried - and I'll do so in the morning - is to build jax and jaxlib without CUDA enabled, although I wouldn't expect such a difference to yield the trace and errors we're seeing when running gptj_ckpt_convert.py.
It turns out that there's an issue with the semi-monolithic build approach taken by the JAX maintainers. While I understand this approach, in that it yields a predictable output, it causes conflicts when libraries in the native build server environment are not utilized. My first step in resolving the issues I encountered was to build jaxlib and jax v0.4.5 from source without CUDA or TPU support. I then found a number of Python module version conflicts not resolved by pip3 when installing jaxlib and jax. I was able to reduce these to patch-version errors, as reported by pip3, by updating the following:
pip3 install thinc spacy numba cudf pydantic
With the version error reports now limited to patch versions, I then attempted the weights conversion, with apparent success. In the next day or two I will update the ticket with further findings and specific version numbers.
As a general matter, the solution here will involve either a number of other projects updating their dependency versioning or recommending a legacy version of JAX for use with FasterTransformer on Arm64 and, at least in the case of Arm, building it from source. Further updates to come....
UPDATE: It appears only the slim weights file was processed successfully. When processing the full weights file, the same error was thrown (this is likely due to my only partial version conflict resolution above):
root@653b52a72f3b:/workspace/jax# python3 /workspace/FasterTransformer/examples/pytorch/gptj/utils/gptj_ckpt_convert.py --output-dir /ft_workspace/full/j6b_ckpt --ckpt-dir /ft_workspace/full/step_383500/
loading
loading shards for part 0
read from checkpoint
Traceback (most recent call last):
File "/workspace/FasterTransformer/examples/pytorch/gptj/utils/gptj_ckpt_convert.py", line 277, in
On my most recent attempt, built using the 0.4.4 tag (rather than main, which seems to be 0.4.5), the following conflicts are reported upon installation:
root@abfe88053f4b:/workspace/jax# cp dist/jaxlib-0.4.4-cp38-cp38-manylinux2014_aarch64.whl /ft_workspace/
root@abfe88053f4b:/workspace/jax# pip install /workspace/jax/dist/jaxlib-0.4.4-cp38-cp38-manylinux2014_aarch64.whl --force-reinstall
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Processing ./dist/jaxlib-0.4.4-cp38-cp38-manylinux2014_aarch64.whl
Collecting scipy>=1.5
Downloading scipy-1.10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (31.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 31.0/31.0 MB 3.9 MB/s eta 0:00:00
Collecting numpy>=1.20
Downloading numpy-1.24.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (14.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.0/14.0 MB 3.9 MB/s eta 0:00:00
Installing collected packages: numpy, scipy, jaxlib
Attempting uninstall: numpy
Found existing installation: numpy 1.22.2
Uninstalling numpy-1.22.2:
Successfully uninstalled numpy-1.22.2
Attempting uninstall: scipy
Found existing installation: scipy 1.6.3
Uninstalling scipy-1.6.3:
Successfully uninstalled scipy-1.6.3
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.0.13 requires pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4, but you have pydantic 1.10.5 which is incompatible.
spacy 3.2.2 requires pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4, but you have pydantic 1.10.5 which is incompatible.
numba 0.56.2 requires numpy<1.24,>=1.18, but you have numpy 1.24.2 which is incompatible.
cudf 22.8.0a0+304.g6ca81bbc78.dirty requires protobuf<3.21.0a0,>=3.20.1, but you have protobuf 3.19.6 which is incompatible.
Successfully installed jaxlib-0.4.4 numpy-1.24.2 scipy-1.10.1
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
root@abfe88053f4b:/workspace/jax#
No conflicts are reported when running:
root@abfe88053f4b:/workspace/jax# pip install -e .
Manually installing these specific versions:
pip3 install pydantic==1.8.2
pip3 install numpy==1.23.4
pip install protobuf==3.20.3
left only the following conflicts, all on protobuf and mostly at patch level:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorboard 2.10.0 requires protobuf<3.20,>=3.9.2, but you have protobuf 3.20.3 which is incompatible.
pytorch-lightning 1.6.5 requires protobuf<=3.20.1, but you have protobuf 3.20.3 which is incompatible.
onnx 1.12.0 requires protobuf<=3.20.1,>=3.12.2, but you have protobuf 3.20.3 which is incompatible.
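To see at a glance which dependency each remaining conflict hinges on, pip's conflict lines can be tabulated. The following is a minimal sketch and not part of pip or FasterTransformer: the regular expression and the `parse_conflicts` helper are hypothetical, and they assume only the "X requires Y, but you have Z which is incompatible." wording seen in the pip output above.

```python
import re

# Matches lines such as:
#   onnx 1.12.0 requires protobuf<=3.20.1,>=3.12.2, but you have protobuf 3.20.3 which is incompatible.
CONFLICT_RE = re.compile(
    r"^(?P<pkg>\S+) (?P<pkg_ver>\S+) requires "
    r"(?P<dep>[A-Za-z0-9_.\-]+)(?P<spec>[<>=!~]\S*), "
    r"but you have (?P=dep) (?P<have>\S+) which is incompatible\.$"
)


def parse_conflicts(pip_output):
    """Return (package, dependency, required_spec, installed_version) tuples."""
    rows = []
    for line in pip_output.splitlines():
        match = CONFLICT_RE.match(line.strip())
        if match:
            rows.append((match["pkg"], match["dep"], match["spec"], match["have"]))
    return rows


sample = """\
tensorboard 2.10.0 requires protobuf<3.20,>=3.9.2, but you have protobuf 3.20.3 which is incompatible.
pytorch-lightning 1.6.5 requires protobuf<=3.20.1, but you have protobuf 3.20.3 which is incompatible.
onnx 1.12.0 requires protobuf<=3.20.1,>=3.12.2, but you have protobuf 3.20.3 which is incompatible.
"""

rows = parse_conflicts(sample)
for pkg, dep, spec, have in rows:
    print(f"{pkg:20s} wants {dep}{spec:20s} installed: {have}")
```

Grouping the rows by dependency makes it easier to look for a single pinned version that satisfies every listed requirement, if one exists.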
It turns out that jaxlib v0.4.3 is the newest version that doesn't result in circular/conflicting dependencies. Unfortunately, it seems that gptj_ckpt_convert.py only successfully processes the slim weights. I don't know whether the full weights format has changed, or the script was never designed to parse the full weights data (including the metadata). In any case, having resolved the dependency issues, we still get errors when parsing the full weights:
python3 /workspace/FasterTransformer/examples/pytorch/gptj/utils/gptj_ckpt_convert.py --output-dir /ft_workspace/full/j6b_ckpt --ckpt-dir /ft_workspace/full/step_383500/
loading
loading shards for part 0
read from checkpoint
Traceback (most recent call last):
File "/workspace/FasterTransformer/examples/pytorch/gptj/utils/gptj_ckpt_convert.py", line 277, in
Hi, bro. Did you succeed in the end?
I had to build them from source, as Nvidia does an exceptionally poor job of package management for the Jetson line. It turned out not to matter though, as I next discovered that fasterTransformer_backend does not support sm_87. As far as I know this has not changed, even after Jensen Huang announced publicly at GTC 2023 that the Triton Server stack, which of course includes FasterTransformer, "Is Now" (as of GTC 23) supported across ALL of Nvidia's hardware offerings. Clearly, this turned out to be a lie.
Got it, thanks for your detailed reply and hard work.