Migrate Gpt3 to NNX.
Description
This PR:
- Migrates the Gpt3 implementation from Linen to NNX, including the following classes:
  - Gpt3MultiHeadAttention
  - Gpt3DecoderLayer
- Fixes the decode function of Gpt3 by casting `decoder_positions` to `int32`. The trainable position embedding layer (an Embed layer) requires integer indices for its lookup, but `decoder_positions` was passed as a float; casting it to `int32` prevents the ValueError (see the sketch below).
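For illustration, here is a minimal sketch of both changes. The class and attribute names are hypothetical, not the actual MaxText code: it shows the NNX pattern of constructing submodules eagerly in `__init__` (versus Linen's lazy `setup()`), plus the `int32` cast applied before the position-embedding lookup.

```python
# Illustrative sketch only: names are hypothetical, not the real MaxText Gpt3 code.
from flax import nnx
import jax.numpy as jnp


class ToyGpt3Decoder(nnx.Module):

  def __init__(self, vocab_size: int, max_len: int, features: int, *, rngs: nnx.Rngs):
    # NNX builds submodules eagerly in __init__ instead of Linen's lazy setup().
    self.token_embed = nnx.Embed(num_embeddings=vocab_size, features=features, rngs=rngs)
    self.position_embed = nnx.Embed(num_embeddings=max_len, features=features, rngs=rngs)

  def __call__(self, decoder_input_tokens, decoder_positions):
    # The decode fix: Embed lookups require integer indices, but
    # decoder_positions can arrive as a float array during decoding.
    decoder_positions = decoder_positions.astype(jnp.int32)
    return self.token_embed(decoder_input_tokens) + self.position_embed(decoder_positions)
```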
Tests
Ran the train command to train gpt3-6b for 10 steps:
python3 -m MaxText.train MaxText/configs/base.yml run_name=gpt3-train-run base_output_directory=gs://maxtext-test/train_gpt3/1/ model_name=gpt3-6b dataset_type=synthetic steps=10
Logs: Linen (before migration), NNX (after migration)
Profile: Linen (before migration), NNX (after migration)
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [X] I have performed a self-review of my code.
- [X] I have necessary comments in my code, particularly in hard-to-understand areas.
- [X] I have run end-to-end tests and provided workload links above if applicable.
- [X] I have made or will make corresponding changes to the doc if needed.
Awesome to see this @hsuan-lun-chiang. Could you please add before/after logs for the training command shown in your test section? It would be great to see that perf is the same before/after here.
Update the description with the before/after logs, thank you.
Results for Train and Jetstream
Test Environment
Machine Type: TPU v6e-8
Train
Executed Command:
python3 -m MaxText.train \
MaxText/configs/base.yml \
run_name=gpt3-train-run \
base_output_directory=gs://lance-maxtext/gpt3-6b-train-before/ \
model_name=gpt3-6b \
dataset_type=synthetic \
steps=10
Results:
Maxengine / Jetstream
Step 1: Launch Maxengine
# On terminal 1
export LIBTPU_INIT_ARGS="--xla_jf_auto_cross_replica_sharding=false --xla_tpu_enable_windowed_einsum_for_reduce_scatter=false --xla_tpu_enable_windowed_einsum_for_all_gather=false --xla_tpu_prefer_latch_optimized_rhs_layouts=true --xla_tpu_enable_experimental_fusion_cost_model=false --xla_tpu_dot_dot_fusion_duplicated=false --xla_tpu_dot_dot_fusion=true --xla_jf_conv_input_fusion=true --xla_jf_conv_output_fusion=true --xla_tpu_rwb_fusion=false --xla_tpu_copy_fusion_pad_unpad_ratio=0 --xla_tpu_licm_size_inflation_ratio=1 --xla_tpu_copy_elision_analysis_allowance=150000 --xla_tpu_copy_insertion_use_region_analysis_limit=10000 --xla_tpu_order_dot_after_layout=true --xla_jf_rematerialization_percent_shared_memory_limit=100 --xla_tpu_use_repeated_instance_for_preferred_prefetch_time=true --xla_tpu_enforce_prefetch_fifo_order=false --xla_tpu_prefetch_interval_picker_size_override=6000000 --xla_tpu_async_copy_bandwidth_scaling_factor=1 --xla_tpu_nd_short_transfer_max_chunks=-1 --xla_tpu_enable_aggressive_broadcast_priority_update=true --xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers=SQRT --xla_tpu_memory_bound_loop_optimizer_options=enabled:true --xla_tpu_enable_copy_fusion=true --xla_tpu_enable_cross_program_prefetch_freeing=false --xla_tpu_enable_dot_strength_reduction=true --xla_tpu_layout_use_dot_grouping=false --xla_tpu_msa_inefficient_use_to_copy_ratio=0.5 --xla_tpu_reduce_loop_fusion_dup_with_unfusable_user=false --xla_tpu_vector_load_fusion_window=1024 --xla_tpu_vector_store_fusion_window=256 --xla_jf_conv_reshape_fusion=false --xla_tpu_input_conv_multi_users=false --xla_tpu_enable_multi_level_input_dot_dot_fusion=false --xla_tpu_enable_multi_level_output_dot_dot_fusion=false --xla_tpu_dot_dot_fusion_separable_convs_only=false --xla_tpu_enable_multi_level_nested_loop_fusion=true --xla_tpu_nested_dot_fusion=true --xla_tpu_enable_multi_level_nested_dot_fusion=false --xla_jf_enable_multi_output_fusion=true --xla_tpu_use_lp_llo_scheduler_for_dot_dot_fusions=false --xla_tpu_enable_flash_attention=true"
python3 -m MaxText.maxengine_server \
MaxText/configs/base.yml \
model_name=gpt-oss-20b \
tokenizer_type=huggingface \
tokenizer_path=openai/gpt-oss-20b \
per_device_batch_size=1 \
max_target_length=1024 \
ici_fsdp_parallelism=1 \
ici_tensor_parallelism=8 \
attention=dot_product \
load_parameters_path=gs://lance-maxtext/gpt-oss-train-after/test_pre_gpt_oss_20b/checkpoints/9/items \
max_prefill_predict_length=128 \
prompt="I love to" \
mla_naive_kvcache=False \
enable_jax_profiler=True hf_access_token=$HF_TOKEN
Step 2: Execute Jetstream
# On terminal 2
JAX_PLATFORMS=tpu python benchmarks/benchmark_serving.py \
--tokenizer=openai/gpt-oss-20b \
--num-prompts 5000 \
--dataset mmlu \
--dataset-path mmlu/data/test/ \
--request-rate 0 \
--warmup-mode sampled \
--save-request-outputs \
--run-eval True \
--use-hf-tokenizer True
Step 3: Output Collection
# On terminal 2
# Define data saving location
RUN=run-$(date +%Y-%m-%d-%H-%M-%S)
echo $RUN
log_dir=$HOME/test_memory/mistral/$RUN
echo $log_dir
# Collect profile data (profiler port 9999, capture duration 6000 ms)
python -m jax.collect_profile 9999 6000 --log_dir=$log_dir --no_perfetto_link
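For context, the two positional arguments to `jax.collect_profile` are the profiler port and the capture duration in milliseconds; the command attaches to a profiler server running inside the workload, presumably what `enable_jax_profiler=True` turns on in the maxengine command above. A minimal standalone sketch of that server side, assuming nothing about MaxText internals:

```python
import jax
import jax.numpy as jnp

# Start a profiler server so that, from another terminal,
#   python -m jax.collect_profile 9999 6000 --log_dir=... --no_perfetto_link
# can attach and capture a 6000 ms trace.
jax.profiler.start_server(9999)

# Keep some work running while the capture is in progress.
x = jnp.ones((1024, 1024))
for _ in range(1000):
  x = (x @ x.T).block_until_ready()
```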
Results:
Decode
Executed Command:
python3 -m MaxText.decode MaxText/configs/base.yml \
model_name=gpt3-6b \
tokenizer_type=huggingface \
tokenizer_path=openai/gpt-oss-20b \
per_device_batch_size=1 \
ici_fsdp_parallelism=2 \
ici_autoregressive_parallelism=4 \
max_prefill_predict_length=128 \
prefill_chunk_size=0 \
prompt="I love to" \
attention=dot_product \
weight_dtype=bfloat16 \
load_parameters_path=gs://lance-maxtext/gpt3-6b-train-before/gpt3-train-run/checkpoints/0/items
For both before and after migration, we received the same error:
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/wanglance_google_com/maxtext/src/MaxText/decode.py", line 211, in <module>
app.run(main)
File "/home/wanglance_google_com/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 316, in run
_run_main(main, args)
File "/home/wanglance_google_com/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/home/wanglance_google_com/maxtext/src/MaxText/decode.py", line 97, in main
params = engine.load_params(rng_load_params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wanglance_google_com/maxtext/src/MaxText/maxengine.py", line 252, in load_params
self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wanglance_google_com/maxtext/src/MaxText/maxtext_utils.py", line 1083, in get_prefill_kv_cache_annotations
abstract_state = jax.eval_shape(init_kv_cache_partial)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wanglance_google_com/maxtext/src/MaxText/maxtext_utils.py", line 1070, in init_kv_cache
model_vars = model.init(
^^^^^^^^^^^
File "/home/wanglance_google_com/maxtext/src/MaxText/layers/models.py", line 64, in init
return nn.Module.init(module, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wanglance_google_com/maxtext/src/MaxText/layers/models.py", line 154, in __call__
logits, hidden_state = self.decoder(
^^^^^^^^^^^^^
File "/home/wanglance_google_com/maxtext/src/MaxText/layers/decoders.py", line 652, in __call__
y = self._apply_embedding(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/wanglance_google_com/maxtext/src/MaxText/layers/decoders.py", line 563, in _apply_embedding
y += embed_as_linen(
^^^^^^^^^^^^^^^
File "/home/wanglance_google_com/maxtext/src/MaxText/layers/nnx_wrappers.py", line 437, in __call__
out = method_fn(module, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wanglance_google_com/maxtext/src/MaxText/layers/embeddings.py", line 143, in __call__
raise ValueError("Input type must be an integer or unsigned integer.")
ValueError: Input type must be an integer or unsigned integer.
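The raise comes from the integer-dtype guard on embedding inputs. A minimal standalone reproduction of the same class of failure, using flax's own `nnx.Embed` rather than the MaxText layer:

```python
import jax.numpy as jnp
from flax import nnx

embed = nnx.Embed(num_embeddings=16, features=4, rngs=nnx.Rngs(0))
float_positions = jnp.zeros((1, 8), dtype=jnp.float32)  # float indices, as in the decode path

try:
  embed(float_positions)  # rejected: embedding lookup requires integer indices
except ValueError as err:
  print(err)

out = embed(float_positions.astype(jnp.int32))  # the cast applied in this PR avoids the error
```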
@hsuan-lun-chiang do you have train profiles for this run? I see some issues with the Jetstream profiles, maybe they were started before the actual requests started happening?
Hi @bvandermoon, here are profiles collected with xplane: Before, After
Command:
python3 -m src.MaxText.train MaxText/configs/base.yml run_name=gpt3-train-run base_output_directory=gs://maxtext-test/train_gpt3/33/ model_name=gpt3-6b dataset_type=synthetic steps=10
Thanks @hsuan-lun-chiang. The profiles look good. I will take one more pass on the PR tomorrow
Thank you! I also rebased the code to the latest version.
Thanks @hsuan-lun-chiang. Looking good, just have a few comments
Thank you @bvandermoon for the comments, I've addressed them and removed the KVCache implementation.
Thanks for the change! At a high level, I noticed:
- accuracy is 0: {'accuracy': 0.0, 'gen_num': 5000}; you will potentially need to use the HF tokenizer during inference
- the profiles are empty: https://xprof.corp.google.com/memory_profile/wanglance-3924113240870141729. You may want to extend the capture to a longer duration.
A few examples for your reference:
- #2435
- #2430
Thank you @RissyRan for helping review this! I've updated the PR description with both the before and after migration profiling results. Regarding the accuracy, Gpt3 is currently experiencing inference issues. As discussed in previous comments, Branden suggested removing the KVCache portion (which fixes this issue) to focus on ensuring the before/after behavior matches for the migration; as a result, we cannot use maxengine to measure inference performance at this time.
This PR has been merged by copybara in the following commit, after resolving some g3 conflicts (cc @SurbhiJainUSC).
https://github.com/AI-Hypercomputer/maxtext/commit/c7acd26cf0e37d40982c08aa86f561e4a43c48e1
We will close this one if no issues are detected today.
Thanks @hengtaoguo for the notice. Let me close this one.