Rafi Witten
New crash:

Traceback (most recent call last):
  File "/home/rwitten/maxtext/pedagogical_examples/host_offload.py", line 55, in
    host_out_shardings = jax.tree.map(lambda x : x.with_memory_kind('pinned_host'), shardings)
  File "/home/rwitten/.local/lib/python3.10/site-packages/jax/_src/tree.py", line 61, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
...
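For context, a minimal sketch of what that line is attempting, assuming a toy 1-D mesh and a tiny pytree of shardings; only jax.tree.map and with_memory_kind come from the traceback, everything else here is illustrative:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative setup: a 1-D mesh over all local devices and a small pytree of shardings.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
shardings = {"w": NamedSharding(mesh, P("data"))}

# The failing line: retarget every sharding in the tree at pinned host memory,
# so the corresponding outputs can be offloaded to the host instead of device HBM.
host_out_shardings = jax.tree.map(lambda x: x.with_memory_kind("pinned_host"), shardings)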
Amazing @borisdayma! We don't actually officially support Mistral (we do support Llama and Gemma), but we're thrilled things are working for you!
Thank you for the comments! (1) Fused attention is on by default for training! We use "splash attention", a custom, faster implementation! (And we're working on accelerated...
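If you want to opt out of the fused kernel (e.g. for debugging), the attention implementation can be overridden on the command line like any other config value; attention=dot_product is the explicit unfused option used in the decode command later in this thread, and it's our assumption that train.py accepts the same flag (the run_name here is just a placeholder):

python MaxText/train.py MaxText/configs/base.yml run_name=attn_debug_$(date +%Y-%m-%d-%H-%M) attention=dot_product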
@A9isha
@logicchains thanks for the tips on GPU convergence! We will experiment with this as we set up convergence regimes for GPUs. @anfals please be aware of this as you do...
We don't support that out of the box. We've found that tuning the LR to a smaller value is a better approach. What is your use case?
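Concretely, the simplest way to try a smaller LR is to override it on the command line; learning_rate is our assumption for the relevant key in base.yml, and the value and run_name below are just placeholders:

python MaxText/train.py MaxText/configs/base.yml run_name=lr_sweep_$(date +%Y-%m-%d-%H-%M) learning_rate=1e-5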
What specific model would you like supported? We would only take this on if we saw sufficient interest (but in practice we see heavy movement towards decoder-only models).
Thank you, good samaritan! This code looks good, but annoyingly you need to sign the CLA for us to be allowed to merge it.
export M_BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
export M_DATASET_PATH=gs://maxtext-dataset
python MaxText/decode.py MaxText/configs/base.yml \
  per_device_batch_size=.25 \
  run_name=runner_$(date +%Y-%m-%d-%H-%M) \
  max_prefill_predict_length=8 \
  max_target_length=16 \
  scan_layers=true \
  attention=dot_product \
  prompt="I love to" \
  global_parameter_scale=1 \
  add_eos=false \
  ici_autoregressive_parallelism=4
I'd recommend not using this optimizer -- it is only for MLPerf. @ZhiyuLi-goog -- can you look at the (quite scary) warning?