Error while running bash command: run_sample_video.sh | Error: "TypeError: missing a required argument: 'segment_ids'"
I get this error when I run this bash command: !bash LWM/scripts/run_sample_video.sh. I have followed all the directions listed in the repo.
/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/content/LWM/lwm/vision_generation.py", line 256, in <module>
run(main)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/content/LWM/lwm/vision_generation.py", line 92, in main
model = FlaxVideoLLaMAForCausalLM(
File "/content/LWM/lwm/vision_llama.py", line 141, in __init__
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_flax_utils.py", line 224, in __init__
params_shape_tree = jax.eval_shape(init_fn, self.key)
File "/content/LWM/lwm/vision_llama.py", line 166, in init_weights
random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"]
File "/content/LWM/lwm/vision_llama.py", line 396, in __call__
outputs = self.transformer(
File "/content/LWM/lwm/vision_llama.py", line 315, in __call__
outputs = self.h(
File "/content/LWM/lwm/llama.py", line 945, in __call__
hidden_states, _ = nn.scan(
File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 151, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 123, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
File "/content/LWM/lwm/llama.py", line 724, in __call__
attn_outputs = self.attention(
File "/content/LWM/lwm/llama.py", line 615, in __call__
attn_output = ring_attention_sharded(
File "/usr/lib/python3.10/inspect.py", line 3186, in bind
return self._bind(args, kwargs)
File "/usr/lib/python3.10/inspect.py", line 3101, in _bind
raise TypeError(msg) from None
TypeError: missing a required argument: 'segment_ids'
Would appreciate some help here.
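The last frames show the `TypeError` being raised from `inspect.Signature.bind`, which checks a call's arguments against a function's signature before the function runs. A minimal sketch of that check, using a hypothetical `attention_fn` stand-in (not LWM's actual `ring_attention_sharded`) whose signature requires `segment_ids`, reproduces the same message when the caller omits that argument. This suggests, without confirming, that the call at `llama.py` line 615 and the signature of the function it invokes disagree about `segment_ids`.

```python
import inspect


# Hypothetical stand-in for an attention function whose signature
# requires `segment_ids`; this is NOT LWM's real implementation.
def attention_fn(query, key, value, bias, segment_ids, *, axis_name="sp"):
    return query


def call_with_bind(fn, *args, **kwargs):
    # Validate the call against fn's signature first, the same check that
    # inspect.Signature.bind performs; a missing argument raises TypeError.
    bound = inspect.signature(fn).bind(*args, **kwargs)
    return fn(*bound.args, **bound.kwargs)


if __name__ == "__main__":
    try:
        call_with_bind(attention_fn, 1, 2, 3, None)  # segment_ids omitted
    except TypeError as err:
        print(err)
```

On Python 3.10 the printed message is `missing a required argument: 'segment_ids'`, matching the last line of the traceback above.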
Seeing the same error. Commit 97ae4b672f0a9d8bc30ab536d4bac42c3d044aff works for me on GPU.
@gabeweisz I get the following error:
(lwm) madhu@madhupc:~/LWM$ bash scripts/run_sample_image.sh
Traceback (most recent call last):
File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/madhu/LWM/lwm/vision_generation.py", line 11, in <module>
This is our Google Colab notebook. Would you mind taking a look and telling us what changes should be made to run this model?
https://colab.research.google.com/drive/1Bx-wRzOspvq5JLctNKRHwHq-vIgw7wlv?usp=sharing
For the version of the repo I pointed you to, it works for me with jax==0.4.25, flax==0.8.2, and chex==0.1.86.
I don't have access to your Google Colab notebook, but maybe the authors of this project will chime in with more information.
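To confirm that an environment matches the combination reported to work here (jax 0.4.25, flax 0.8.2, and chex 0.1.86 on top of commit 97ae4b6), a small sketch using the standard library's `importlib.metadata` can list any mismatches. The `check_versions` helper is hypothetical, and the expected versions are simply the ones quoted in this comment, not an official requirement of the repo.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions reported in this thread as working with commit 97ae4b6 on GPU.
EXPECTED = {"jax": "0.4.25", "flax": "0.8.2", "chex": "0.1.86"}


def check_versions(expected=EXPECTED):
    """Return {package: (expected, installed)} for every mismatch.

    `installed` is None when the distribution is not installed at all.
    """
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have = version(pkg)
        except PackageNotFoundError:
            have = None
        if have != want:
            mismatches[pkg] = (want, have)
    return mismatches


if __name__ == "__main__":
    for pkg, (want, have) in check_versions().items():
        print(f"{pkg}: expected {want}, found {have or 'not installed'}")
```

An empty result means the installed versions match the ones quoted above; it does not by itself guarantee that the GPU setup works.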
So on GPU you used commit 97ae4b6 and followed only the instructions for that version? Or did you run any other commands? Also, would you mind sharing your entire requirements.txt file? The versions in the requirements.txt from 97ae4b6 differ from the ones you mentioned. I am struggling to get this working on my GPU.
I used commit 97ae4b6 and did not change anything.
I installed packages using the requirements.txt in that commit, and then manually updated the two packages I mentioned above using pip.
I most likely have a different GPU than you do, but this is what worked for me.
The newest commit (https://github.com/LargeWorldModel/LWM/commit/b8e36023d17d965f40071dcbc6dcdb1865d84a49) fixes this error for me.