
Migrate Gpt3 to NNX.

Open hsuan-lun-chiang opened this issue 5 months ago • 7 comments

Description

This PR:

  1. Migrates the Gpt3 implementation from Linen to NNX (see the sketch of the migration pattern below), including the following classes:
     • Gpt3MultiHeadAttention
     • Gpt3DecoderLayer
  2. Fixes the decode function of Gpt3 by casting decoder_positions to int32. The trainable position embedding layer (Embed) requires integer indices for its lookup, but decoder_positions was passed as a float; the cast prevents the resulting ValueError.
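
For readers unfamiliar with NNX, here is a minimal, hypothetical sketch of the Linen-to-NNX migration pattern (a toy projection layer, not the actual Gpt3 classes from this PR): Linen creates parameters lazily inside @nn.compact, while NNX creates them eagerly in __init__, so input feature sizes become explicit constructor arguments.

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx

class LinenProj(nn.Module):  # before: lazy shape inference via @nn.compact
  features: int

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)

class NNXProj(nnx.Module):  # after: eager parameter creation in __init__
  def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
    self.proj = nnx.Linear(in_features, out_features, rngs=rngs)

  def __call__(self, x):
    return self.proj(x)

x = jnp.ones((2, 8))
linen_model = LinenProj(features=4)
variables = linen_model.init(jax.random.PRNGKey(0), x)
y_linen = linen_model.apply(variables, x)

nnx_model = NNXProj(in_features=8, out_features=4, rngs=nnx.Rngs(0))
y_nnx = nnx_model(x)
assert y_linen.shape == y_nnx.shape == (2, 4)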

Tests

Ran train command to train gpt3-6b for 10 steps:

python3 -m MaxText.train  MaxText/configs/base.yml run_name=gpt3-train-run base_output_directory=gs://maxtext-test/train_gpt3/1/ model_name=gpt3-6b dataset_type=synthetic steps=10

Logs: Linen (before migration), NNX (after migration)

Profile: Linen (before migration), NNX (after migration)

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [X] I have performed a self-review of my code.
  • [X] I have necessary comments in my code, particularly in hard-to-understand areas.
  • [X] I have run end-to-end tests and provided workload links above if applicable.
  • [X] I have made or will make corresponding changes to the doc if needed.

hsuan-lun-chiang avatar Aug 01 '25 10:08 hsuan-lun-chiang

Awesome to see this @hsuan-lun-chiang. Could you please add before/after logs for the training command shown in your test section? It would be great to see that perf is the same before/after here.

Updated the description with the before/after logs, thank you.

hsuan-lun-chiang avatar Aug 05 '25 06:08 hsuan-lun-chiang

Results for Train and Jetstream

Test Environment

Machine Type: TPU v6e-8

Train

Executed Command:

python3 -m MaxText.train \
MaxText/configs/base.yml \
run_name=gpt3-train-run \
base_output_directory=gs://lance-maxtext/gpt3-6b-train-before/ \
model_name=gpt3-6b \
dataset_type=synthetic \
steps=10

Results:


Maxengine / Jetstream

Step 1: Launch Maxengine

# On terminal 1
export LIBTPU_INIT_ARGS="--xla_jf_auto_cross_replica_sharding=false --xla_tpu_enable_windowed_einsum_for_reduce_scatter=false --xla_tpu_enable_windowed_einsum_for_all_gather=false --xla_tpu_prefer_latch_optimized_rhs_layouts=true --xla_tpu_enable_experimental_fusion_cost_model=false --xla_tpu_dot_dot_fusion_duplicated=false --xla_tpu_dot_dot_fusion=true --xla_jf_conv_input_fusion=true --xla_jf_conv_output_fusion=true --xla_tpu_rwb_fusion=false --xla_tpu_copy_fusion_pad_unpad_ratio=0 --xla_tpu_licm_size_inflation_ratio=1 --xla_tpu_copy_elision_analysis_allowance=150000 --xla_tpu_copy_insertion_use_region_analysis_limit=10000 --xla_tpu_order_dot_after_layout=true --xla_jf_rematerialization_percent_shared_memory_limit=100 --xla_tpu_use_repeated_instance_for_preferred_prefetch_time=true --xla_tpu_enforce_prefetch_fifo_order=false --xla_tpu_prefetch_interval_picker_size_override=6000000 --xla_tpu_async_copy_bandwidth_scaling_factor=1 --xla_tpu_nd_short_transfer_max_chunks=-1 --xla_tpu_enable_aggressive_broadcast_priority_update=true --xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers=SQRT --xla_tpu_memory_bound_loop_optimizer_options=enabled:true --xla_tpu_enable_copy_fusion=true --xla_tpu_enable_cross_program_prefetch_freeing=false --xla_tpu_enable_dot_strength_reduction=true --xla_tpu_layout_use_dot_grouping=false --xla_tpu_msa_inefficient_use_to_copy_ratio=0.5 --xla_tpu_reduce_loop_fusion_dup_with_unfusable_user=false --xla_tpu_vector_load_fusion_window=1024 --xla_tpu_vector_store_fusion_window=256 --xla_jf_conv_reshape_fusion=false --xla_tpu_input_conv_multi_users=false --xla_tpu_enable_multi_level_input_dot_dot_fusion=false --xla_tpu_enable_multi_level_output_dot_dot_fusion=false --xla_tpu_dot_dot_fusion_separable_convs_only=false --xla_tpu_enable_multi_level_nested_loop_fusion=true --xla_tpu_nested_dot_fusion=true --xla_tpu_enable_multi_level_nested_dot_fusion=false --xla_jf_enable_multi_output_fusion=true --xla_tpu_use_lp_llo_scheduler_for_dot_dot_fusions=false --xla_tpu_enable_flash_attention=true"

python3 -m MaxText.maxengine_server \
MaxText/configs/base.yml \
model_name=gpt-oss-20b \
tokenizer_type=huggingface \
tokenizer_path=openai/gpt-oss-20b \
per_device_batch_size=1 \
max_target_length=1024 \
ici_fsdp_parallelism=1 \
ici_tensor_parallelism=8 \
attention=dot_product \
load_parameters_path=gs://lance-maxtext/gpt-oss-train-after/test_pre_gpt_oss_20b/checkpoints/9/items \
max_prefill_predict_length=128 \
prompt="I love to" \
mla_naive_kvcache=False \
enable_jax_profiler=True hf_access_token=$HF_TOKEN

Step 2: Execute Jetstream

# On terminal 2
JAX_PLATFORMS=tpu python benchmarks/benchmark_serving.py \
--tokenizer=openai/gpt-oss-20b \
--num-prompts 5000 \
--dataset mmlu \
--dataset-path mmlu/data/test/ \
--request-rate 0 \
--warmup-mode sampled \
--save-request-outputs \
--run-eval True \
--use-hf-tokenizer True

Step 3: Output Collection

# On terminal 2
# Define data saving location
RUN=run-$(date +%Y-%m-%d-%H-%M-%S)
echo $RUN
log_dir=$HOME/test_memory/mistral/$RUN
echo $log_dir

# Collect data (args: profiler port, capture duration in ms)
python -m jax.collect_profile 9999 6000 --log_dir=$log_dir --no_perfetto_link

Results:


Decode

Executed Command:

python3 -m MaxText.decode MaxText/configs/base.yml \
model_name=gpt3-6b \
tokenizer_type=huggingface \
tokenizer_path=openai/gpt-oss-20b \
per_device_batch_size=1 \
ici_fsdp_parallelism=2 \
ici_autoregressive_parallelism=4 \
max_prefill_predict_length=128 \
prefill_chunk_size=0 \
prompt="I love to" \
attention=dot_product \
weight_dtype=bfloat16 \
load_parameters_path=gs://lance-maxtext/gpt3-6b-train-before/gpt3-train-run/checkpoints/0/items

For both before and after migration, we received the same error:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/wanglance_google_com/maxtext/src/MaxText/decode.py", line 211, in <module>
    app.run(main)
  File "/home/wanglance_google_com/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/home/wanglance_google_com/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/decode.py", line 97, in main
    params = engine.load_params(rng_load_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/maxengine.py", line 252, in load_params
    self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations(
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/maxtext_utils.py", line 1083, in get_prefill_kv_cache_annotations
    abstract_state = jax.eval_shape(init_kv_cache_partial)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/maxtext_utils.py", line 1070, in init_kv_cache
    model_vars = model.init(
                 ^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/models.py", line 64, in init
    return nn.Module.init(module, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/models.py", line 154, in __call__
    logits, hidden_state = self.decoder(
                           ^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/decoders.py", line 652, in __call__
    y = self._apply_embedding(
        ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/decoders.py", line 563, in _apply_embedding
    y += embed_as_linen(
         ^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/nnx_wrappers.py", line 437, in __call__
    out = method_fn(module, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/embeddings.py", line 143, in __call__
    raise ValueError("Input type must be an integer or unsigned integer.")
ValueError: Input type must be an integer or unsigned integer.
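
For context, the dtype check raised here mirrors the one in Flax's own Embed layer, so both the failure and the int32-cast fix from this PR can be reproduced in isolation. A minimal sketch using standalone Flax, not the MaxText code path:

import jax
import jax.numpy as jnp
import flax.linen as nn

embed = nn.Embed(num_embeddings=16, features=4)
int_positions = jnp.arange(4, dtype=jnp.int32)
params = embed.init(jax.random.PRNGKey(0), int_positions)

# Float positions trip the dtype check and raise the same ValueError.
float_positions = jnp.arange(4, dtype=jnp.float32)
try:
    embed.apply(params, float_positions)
except ValueError as e:
    print(e)  # Input type must be an integer or unsigned integer.

# Casting to int32 (as this PR does for decoder_positions) makes the lookup succeed.
out = embed.apply(params, float_positions.astype(jnp.int32))
print(out.shape)  # (4, 4)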

ecnal-cienet avatar Oct 09 '25 22:10 ecnal-cienet

@hsuan-lun-chiang do you have train profiles for this run? I see some issues with the Jetstream profiles, maybe they were started before the actual requests started happening?

Hi @bvandermoon, here are profiles collected with xplane: Before, After

command: python3 -m src.MaxText.train MaxText/configs/base.yml run_name=gpt3-train-run base_output_directory=gs://maxtext-test/train_gpt3/33/ model_name=gpt3-6b dataset_type=synthetic steps=10

hsuan-lun-chiang avatar Nov 04 '25 11:11 hsuan-lun-chiang

Thanks @hsuan-lun-chiang. The profiles look good. I will take one more pass on the PR tomorrow

bvandermoon avatar Nov 05 '25 07:11 bvandermoon

Thank you! I also rebased the code to the latest version.

hsuan-lun-chiang avatar Nov 06 '25 01:11 hsuan-lun-chiang

Thanks @hsuan-lun-chiang. Looking good, just have a few comments

Thank you @bvandermoon for the comments, I've addressed them and removed the KVCache implementation.

hsuan-lun-chiang avatar Nov 11 '25 08:11 hsuan-lun-chiang

Thanks for the change! At a high level, I noticed:

  1. Accuracy is 0 ({'accuracy': 0.0, 'gen_num': 5000}); you will potentially need to use the HF tokenizer during inference.
  2. The profiles are empty (https://xprof.corp.google.com/memory_profile/wanglance-3924113240870141729); you may want to capture for a longer duration.

A few examples for your reference:

  • #2435
  • #2430

Thank you @RissyRan for helping review this! I've updated the PR description with both the before and after migration profiling results. Regarding the accuracy: Gpt3 is currently experiencing inference issues. As discussed in previous comments, Branden suggested removing the KVCache portion (which fixes this issue) to focus on ensuring the before/after behavior matches for the migration, so we cannot use maxengine to measure inference performance at this time.

Profile: Linen (before migration), NNX (after migration)

hsuan-lun-chiang avatar Nov 20 '25 11:11 hsuan-lun-chiang

This PR has been merged by copybara in the following commit, after resolving some g3 conflicts (cc @SurbhiJainUSC ).

https://github.com/AI-Hypercomputer/maxtext/commit/c7acd26cf0e37d40982c08aa86f561e4a43c48e1

We will close this one if no issues are detected today.

hengtaoguo avatar Dec 02 '25 20:12 hengtaoguo

Thanks @hengtaoguo for the notice. Let me close this one.

xibinliu avatar Dec 10 '25 00:12 xibinliu