While running the command "bash scripts/run_vision_chat.sh", an error happened. How can I fix it?
(lwm) llm@llm-PowerEdge-R730xd:~/projects/LWM-main$ bash scripts/run_vision_chat.sh
I0221 14:02:43.257625 139932541391232 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0221 14:02:43.260045 139932541391232 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
100%|██████████| 1/1 [00:05<00:00, 5.59s/it]
Traceback (most recent call last):
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 254, in
run(main)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 250, in main
output = sampler(prompts, FLAGS.max_n_frames)[0]
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 230, in call
output, self.sharded_rng = self._forward_generate(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 781, in infer_params
return common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
ans = call(fun, *args)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 206, in fn
output = self.model.generate(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 429, in generate
return self._sample(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 733, in _sample
state = sample_search_body_fn(state)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 704, in sample_search_body_fn
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 232, in call
outputs = self.module.apply(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1511, in apply
return apply(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/core/scope.py", line 934, in wrapper
y = fn(root, *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 2082, in scope_fn
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 401, in call
outputs = self.transformer(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 313, in call
input_text_embeds = self.wte(jnp.where(vision_masks, 0, input_ids))
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 836, in _call_wrapped_method
self._try_setup()
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1094, in _try_setup
self.setup()
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/linear.py", line 771, in setup
self.embedding = self.param('embedding',
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1263, in param
v = self.scope.param(name, init_fn, *init_args, unbox=unbox)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/core/scope.py", line 842, in param
raise errors.ScopeParamNotFoundError(name, self.path_text)
flax.errors.ScopeParamNotFoundError: Could not find parameter named "embedding" in scope "/transformer/wte". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamNotFoundError)
Thank you in advance. Related information is shown below.
run_vision_chat.sh
#! /bin/bash
# Launch the LWM vision-chat demo: resolves the project root relative to this
# script, sets PYTHONPATH, and runs lwm.vision_chat with the model/checkpoint
# paths configured below.

# Directory containing this script, and the project root one level above it.
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"

# Model/checkpoint paths, relative to the project root.
export llama_tokenizer_path="LWM-Chat-1M-Jax/tokenizer.model"
export vqgan_checkpoint="LWM-Chat-1M-Jax/vqgan"
export lwm_checkpoint="LWM-Chat-1M-Jax/params"
export input_file="demo.jpg"

# NOTE: the trailing backslashes are required so every --flag reaches the
# single python3 invocation; without them each flag line runs as a separate
# shell command and the module starts with no arguments.
python3 -u -m lwm.vision_chat \
    --prompt="What is the image about?" \
    --input_file="$input_file" \
    --vqgan_checkpoint="$vqgan_checkpoint" \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --max_n_frames=8 \
    --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)" \
    --load_checkpoint="params::$lwm_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path" \
    2>&1 | tee ~/output.log
# Keep the terminal open until the user presses Enter.
read
pip list
Package Version
absl-py 2.1.0
aiohttp 3.9.3
aiosignal 1.3.1
appdirs 1.4.4
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
build 1.0.3
cachetools 5.3.2
certifi 2024.2.2
charset-normalizer 3.3.2
chex 0.1.82
click 8.1.7
cloudpickle 3.0.0
contextlib2 21.6.0
datasets 2.13.0
decorator 5.1.1
decord 0.6.0
dill 0.3.6
docker-pycreds 0.4.0
einops 0.7.0
etils 1.7.0
exceptiongroup 1.2.0
executing 2.0.1
filelock 3.13.1
flax 0.7.0
frozenlist 1.4.1
fsspec 2024.2.0
gcsfs 2024.2.0
gitdb 4.0.11
GitPython 3.1.42
google-api-core 2.17.1
google-auth 2.28.0
google-auth-oauthlib 1.2.0
google-cloud-core 2.4.1
google-cloud-storage 2.14.0
google-crc32c 1.5.0
google-resumable-media 2.7.0
googleapis-common-protos 1.62.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.34.0
imageio-ffmpeg 0.4.9
importlib-resources 6.1.1
ipdb 0.13.13
ipython 8.21.0
jax 0.4.23
jaxlib 0.4.23+cuda12.cudnn89
jedi 0.19.1
markdown-it-py 3.0.0
matplotlib-inline 0.1.6
mdurl 0.1.2
ml-collections 0.1.1
ml-dtypes 0.3.2
msgpack 1.0.7
multidict 6.0.5
multiprocess 0.70.14
nest-asyncio 1.6.0
numpy 1.26.4
nvidia-cublas-cu12 12.3.4.1
nvidia-cuda-cupti-cu12 12.3.101
nvidia-cuda-nvcc-cu12 12.3.107
nvidia-cuda-nvrtc-cu12 12.3.107
nvidia-cuda-runtime-cu12 12.3.101
nvidia-cudnn-cu12 8.9.7.29
nvidia-cufft-cu12 11.0.12.1
nvidia-cusolver-cu12 11.5.4.101
nvidia-cusparse-cu12 12.2.0.103
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
oauthlib 3.2.2
opt-einsum 3.3.0
optax 0.1.7
orbax-checkpoint 0.5.3
packaging 23.2
pandas 2.2.0
parso 0.8.3
pexpect 4.9.0
pillow 10.2.0
pip 23.3.1
prompt-toolkit 3.0.43
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.0
pyasn1 0.5.1
pyasn1-modules 0.3.0
Pygments 2.17.2
pyproject_hooks 1.0.0
python-dateutil 2.8.2
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.0
rsa 4.9
scipy 1.12.0
sentencepiece 0.2.0
sentry-sdk 1.40.5
setproctitle 1.3.3
setuptools 68.2.2
six 1.16.0
smmap 5.0.1
stack-data 0.6.3
tensorstore 0.1.53
tiktoken 0.6.0
tokenizers 0.13.3
tomli 2.0.1
toolz 0.12.1
tqdm 4.66.2
traitlets 5.14.1
transformers 4.29.2
tux 0.0.2
typing_extensions 4.9.0
tzdata 2024.1
urllib3 2.2.1
wandb 0.16.3
wcwidth 0.2.13
wheel 0.41.2
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0
Thanks. Which model did you use? And are you able to run "bash run_vision_chat.sh" successfully?
Hello, I encountered the same problem, and eventually found out that the model had not been fully uploaded to the server. Could you please check whether the model file size on your end is consistent with the original? If the files are complete, the command should run without throwing this error. If you still encounter issues, please post them again for further review.你应该是模型文件没下载完整,或者传输时候没传完整,但没看到这个问题,你配置的环境也没问题的,jax和flax库的版本都是对的