Converted MLPerf GPT-3 ckpt starts with a worse loss
Hello, we converted the paxml checkpoint and resumed training with the following config:
base_config: "base.yml"
tokenizer_path: "/dockerx/vocab/c4_en_301_5Mexp2_spm.model"
dataset_type: "tfds"
dataset_path: "/ckpts/c4_mlperf_dataset"
dataset_name: "en:3.0.4"
eval_dataset_name: "en:3.0.5"
split: "train2"
tokenize_eval_data: False
eval_data_column: "ids"
add_bos: False
add_eos: False
eval_split: "validation_tokenized_5662seqs"
eval_interval: 10 # number of train steps between eval steps
target_eval_loss: 2.69 # early stop once reaching target eval_loss
enable_checkpointing: True
save_interval_steps: 5
# Args coming from the NVIDIA spreadsheet http://shortn/_W9CzVbtQde and
# third_party/py/maxtext/configs/a3/llama_2_7b.
hardware: "gpu"
steps: 10
model_name: "gpt3-175b" # this model config is unchanged
attention: "cudnn_flash_te"
gradient_accumulation_steps: 1
dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1
dcn_pipeline_parallelism: 1
dcn_tensor_parallelism: 1
dcn_sequence_parallelism: 1
ici_fsdp_parallelism: 8
ici_data_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_pipeline_parallelism: 1
per_device_batch_size: 5
max_target_length: 2048
remat_policy: "full"
use_iota_embed: True
scan_layers: False
async_checkpointing: False
logits_dot_in_fp32: False
megablox: False
dtype: "bfloat16"
quantization: ""
quantize_kvcache: False
kv_quant_axis: "heads_and_dkv"
kv_quant_dtype: "int8"
weight_dtype: bfloat16
checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint
skip_first_n_steps_for_profiler: 3
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
# For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
# Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
# The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']],
['activation_heads', ['tensor','sequence']],
['activation_kv_heads', ['tensor','sequence']],
['activation_length', 'sequence'],
['activation_embed', 'tensor'],
['activation_mlp', 'tensor'],
['activation_kv', 'tensor'],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose',]],
['activation_kv_head_dim', 'tensor'],
['activation_vocab', ['tensor', 'sequence']],
['activation_vocab', 'tensor'],
['activation_vocab', 'sequence'],
['activation_stage','stage'],
['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
['vocab', ['tensor', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
['embed', ['fsdp', 'sequence']],
['norm', 'fsdp'],
['heads', ['tensor', 'autoregressive']],
['layers', 'stage'],
['kv', []],
['kv_heads', ['tensor', 'autoregressive']],
['kv_head_dim', []],
['cache_batch', []],
['cache_heads', ['autoregressive', 'tensor']],
['cache_kv', []],
['cache_sequence', []],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
The tokenizer and data splits (3.0.4, 3.0.5) were downloaded from the mlperf2 bucket. I have also tried using the c4_mlperf dataset_type like this:
base_config: "base.yml"
tokenizer_path: "/dockerx/vocab/c4_en_301_5Mexp2_spm.model"
dataset_type: "c4_mlperf"
dataset_path: "/ckpts/c4_mlperf_dataset"
dataset_name: "en:3.0.4"
eval_dataset_name: "en:3.0.5"
split: "train2"
eval_split: "validation_tokenized_5662seqs"
python maxtext/MaxText/train.py /dockerx/maxtext/MaxText/configs/gpt3_175b_gpu.yml base_output_directory=/ckpts/paxml/gpt3-conversion run_name=gpt3-conversion steps=4010 scan_layers=true
^ scan_layers is set to true, in line with how we converted the ckpt
completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295
To see full metrics 'tensorboard --logdir=/ckpts/paxml/gpt3-conversion/gpt3-conversion/tensorboard/'
completed step: 4001, seconds: 12.945, TFLOP/s/device: 170.297, Tokens/s/device: 158.213, total_weights: 65504, loss: 7.687, perplexity: 2179.917
completed step: 4002, seconds: 11.886, TFLOP/s/device: 185.471, Tokens/s/device: 172.310, total_weights: 65504, loss: 7.739, perplexity: 2297.215
completed step: 4003, seconds: 11.885, TFLOP/s/device: 185.479, Tokens/s/device: 172.318, total_weights: 65504, loss: 7.597, perplexity: 1992.680
completed step: 4004, seconds: 11.931, TFLOP/s/device: 184.759, Tokens/s/device: 171.649, total_weights: 65504, loss: 7.680, perplexity: 2165.097
completed step: 4005, seconds: 11.913, TFLOP/s/device: 185.043, Tokens/s/device: 171.912, total_weights: 65504, loss: 7.663, perplexity: 2128.778
completed step: 4006, seconds: 11.945, TFLOP/s/device: 184.546, Tokens/s/device: 171.451, total_weights: 65504, loss: 7.582, perplexity: 1963.248
completed step: 4007, seconds: 11.913, TFLOP/s/device: 185.048, Tokens/s/device: 171.918, total_weights: 65504, loss: 7.648, perplexity: 2096.574
completed step: 4008, seconds: 12.013, TFLOP/s/device: 183.498, Tokens/s/device: 170.478, total_weights: 65504, loss: 7.524, perplexity: 1851.645
completed step: 4009, seconds: 11.920, TFLOP/s/device: 184.929, Tokens/s/device: 171.807, total_weights: 65504, loss: 7.618, perplexity: 2034.629
^ Training starts with a very high loss; we expected something closer to 2.77.
We have verified from the logs that training loads the right checkpoint, the correct data splits, and the tokenizer.
@ZhiyuLi-goog thanks again for your help with other issues. Do you see any problems with the config or know why the loss is much higher?
I have never tried on GPU. To narrow down the root cause, could you try with normal attention?
attention: "dot_product"
with attention: "dot_product" : completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295 To see full metrics 'tensorboard --logdir=/ckpts/paxml/gpt3-conversion/gpt3-conversion/tensorboard/' completed step: 4001, seconds: 39.677, TFLOP/s/device: 277.795, Tokens/s/device: 258.083, total_weights: 327520, loss: 7.638, perplexity: 2076.376 completed step: 4002, seconds: 39.883, TFLOP/s/device: 276.359, Tokens/s/device: 256.749, total_weights: 327520, loss: 7.646, perplexity: 2092.290
I get similar loss as before
Oh, could you try something like
python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b
instead of changing base.yml? You can find the exact model setup in gpt3-175b.yml, and there is some additional setup for gpt3-175b.
# these flags might be relevant to output results
logits_via_embedding: True
normalize_embedding_logits: False
logits_dot_in_fp32: False
normalization_layer_epsilon: 1.e-05
use_iota_embed: True
opt_type: "adam_pax"
I think logits_via_embedding: True should be the most important one.
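For reference, here is a minimal sketch of what these flags control when computing logits (the function and argument names are illustrative, not MaxText's actual implementation, and the exact normalization form is an assumption): with logits_via_embedding the output projection reuses the transposed input-embedding table, while normalize_embedding_logits and logits_dot_in_fp32 adjust the scaling and precision of that dot product.

```python
import jax.numpy as jnp

def compute_logits(hidden, embedding_table, normalize=False, dot_in_fp32=False):
    """Embedding-tied logits sketch.

    hidden:          [batch, seq, emb] decoder output activations (bfloat16)
    embedding_table: [vocab, emb] shared input-embedding weights
    """
    if normalize:
        # normalize_embedding_logits (the exact scaling used here is an assumption)
        embedding_table = embedding_table / jnp.sqrt(embedding_table.shape[-1])
    if dot_in_fp32:
        # logits_dot_in_fp32: perform the projection in float32
        hidden = hidden.astype(jnp.float32)
        embedding_table = embedding_table.astype(jnp.float32)
    # logits_via_embedding: reuse the (transposed) input embedding as the output projection
    return jnp.einsum("bse,ve->bsv", hidden, embedding_table)
```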
I tested these out. First running
python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b
and then also adding the other relevant flags you posted, one by one; all of them start with the bad loss (7.6x). So it's not flash attention, the tokenizer (the validation set is pretokenized and the eval loss is also bad), or the config args (since I tried the flags you suggested).
It's probably something to do with the model weights.
I can take a look at the full logs if you have them. The final effective config should be in that log.
Checked the log. All updated parameters matched and I didn't find anything suspicious.
Thanks for checking. Yeah, it's strange that it starts with a bad loss. I also tested the tokenizer, and it seems fine.
The only thing I found that looks weird is:
+ Config param weight_dtype: float32
- Config param weight_dtype: bfloat16
Could you try using weight_dtype: float32 instead of bfloat16? Activations are calculated in bfloat16, while all parameters and optimizer state should be kept in float32 for better convergence.
However, I do not expect such a big gap.
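As an illustration of that recommendation, here is a minimal mixed-precision sketch (the names and shapes are made up): parameters and optimizer state stay in float32 (weight_dtype: float32), and only the forward-pass compute casts to bfloat16 (dtype: "bfloat16").

```python
import jax
import jax.numpy as jnp
import optax

# Parameters and optimizer state live in float32 (weight_dtype: float32) ...
params = {"w": jnp.zeros((1024, 1024), dtype=jnp.float32)}
tx = optax.adam(1e-4)
opt_state = tx.init(params)

def loss_fn(params, x):
    # ... while the forward pass is computed in bfloat16 (dtype: "bfloat16").
    w_bf16 = params["w"].astype(jnp.bfloat16)
    y = x.astype(jnp.bfloat16) @ w_bf16
    return jnp.mean(y.astype(jnp.float32) ** 2)

@jax.jit
def train_step(params, opt_state, x):
    grads = jax.grad(loss_fn)(params, x)  # grads follow the float32 param dtype
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state
```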
Tried weight_dtype: float32 as well. Same problem.
I'm wondering if we can send you our converted ckpt so you can load it and verify whether it's a ckpt problem?
I can give it a try on the TPU side.
By the way, would it be useful to print the mean of each param and optimizer state after conversion?
I'm not sure if it will be useful. We also loaded the pax ckpt directly in paxml, and it starts at the right loss. So at this point, we suspect something is going wrong during conversion.
It would be easiest if you have a converted ckpt; I can directly compare your converted ckpt against ours. If you have output logs from the conversion script, I can take a look at those as well.
We didn't try that on GPU; I guess something might behave differently there.
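One way to do that comparison, sketched under assumptions (the paths below are placeholders, both checkpoints are assumed to have the same pytree structure, and restoring a full 175B state needs a lot of host memory):

```python
import jax
import numpy as np
import orbax.checkpoint as ocp

# Hypothetical paths to the two converted checkpoints' item directories.
OURS = "/ckpts/paxml/gpt3-conversion/gpt3-conversion/checkpoints/4000/items"
REFERENCE = "gs://some-bucket/reference-conversion/checkpoints/4000/items"

ckptr = ocp.PyTreeCheckpointer()
ours = ckptr.restore(OURS)        # caution: a full 175B restore needs a lot of host memory
theirs = ckptr.restore(REFERENCE)

# Assumes both pytrees have identical structure, so the flattened leaves line up.
for (path, a), (_, b) in zip(
    jax.tree_util.tree_flatten_with_path(ours)[0],
    jax.tree_util.tree_flatten_with_path(theirs)[0],
):
    name = jax.tree_util.keystr(path)
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    print(f"{name}: mean ours={a.mean():.6f} ref={b.mean():.6f} max|diff|={np.abs(a - b).max():.6f}")
```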
Great, we will share the converted ckpt and the conversion logs. Do you have a gcloud bucket that I could push them to, or do you recommend some other way?
It would be great if you can share an open gcloud bucket with us. By the way, which conversion script are you using: the one in the mlperf 4.0 submission or the one in the maxtext main branch?
OK, let me do that. We tried both versions, and with both we get the same problem.
Gotcha, thank you for the info!
We have created the bucket and will share access with you soon (I got your Google email from one of your commits).
Hello again,
You should finally have access to the bucket containing all the ckpts
We have shared three ckpts:
1. gpt3-conversion-forked/ - ckpt created with the mlperf fork
2. gpt3-conversion-noscan-cache/
3. gpt3-conversion/
The second and third were both created with the latest branch: the second with scan_layers=false and the third with scan_layers=true.
Let us know if you are able to access them and if you have any questions. Thanks!
Thank you @gramesh-amd
I will test with your ckpt.
I wrote a script to look at the checkpoint that we generated and compare it to the original data. What I found (at least in my initial look) is that each row has the first 1/4 of its data (starting at element 0) populated, and the rest of it is 0. This is interesting because we used 4 nodes with 8 GPUs each, with intra- and inter-node FSDP. This makes me think that we have data from the first node and 0's from the other three nodes.
Any suggestions for how we can confirm that this is what happened and debug it?
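One way to confirm the per-node pattern described above, as a rough sketch (the path is a placeholder, the leading axis is assumed to be the FSDP-sharded one, and restoring the full state needs a lot of host memory): split each restored array into one chunk per node along its first axis and report how much of each chunk is nonzero.

```python
import jax
import numpy as np
import orbax.checkpoint as ocp

CKPT_ITEMS = "/ckpts/paxml/gpt3-conversion/gpt3-conversion/checkpoints/4000/items"  # placeholder path
NUM_NODES = 4  # FSDP was split across 4 nodes of 8 GPUs

state = ocp.PyTreeCheckpointer().restore(CKPT_ITEMS)

for path, leaf in jax.tree_util.tree_flatten_with_path(state)[0]:
    arr = np.asarray(leaf, dtype=np.float32)
    if arr.ndim == 0:
        continue
    # Split along the leading (assumed FSDP-sharded) axis: one chunk per node.
    chunks = np.array_split(arr, NUM_NODES, axis=0)
    nonzero = [f"{(c != 0).mean():.2f}" for c in chunks]
    print(f"{jax.tree_util.keystr(path)}: nonzero fraction per node chunk = {nonzero}")
```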
@gabeweisz Great finding and thank you for looking into it!
I am just wondering how you ran the script? I would expect:
- the script to run on each device (this is a no-brainer on TPU, and I am wondering if we need some tweak on GPU)
- the output directory to be accessible to each device
Key idea:
# each process/device will distribute the weight (fsdp) and each device only keep its own shard
result = jax.make_array_from_single_device_arrays(
    shape,
    sharding,
    [jax.device_put(np.array(arr[index]), d) for d, index in sharding.addressable_devices_indices_map(shape).items()],
)
...
...
# distribute saving to an output directory like a gcs bucket which can be accessible by all devices.
if save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
  max_logging.log(f"saved a checkpoint at step {converted_state.step}")
# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(converted_state.step):
  checkpoint_manager.wait_until_finished()
  sys.exit()
Example:
I tried the script yesterday and it worked for me.
# checkpoint loading from a fixed folder
RUN_NAME=ckpt
BASE_OUTPUT_DIR=gs://path/to/output
python MaxText/convert_gpt3_ckpt_from_paxml.py \
--paxml-ckpt-path=gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000 \
--maxtext-model-name=gpt3-175b \
--run-name=$RUN_NAME \
--base-output-directory=$BASE_OUTPUT_DIR
Note that the output directory is a GCS bucket, which is accessible by all devices.
We ran the script in a way very similar to how you ran it - my colleague Gowtham has shared what we did earlier.
When we ran this, we didn't have a shared NFS big enough for all the nodes and did not have access to a GCS bucket - each node was writing to its own local directory.
I did check, and once the script was finished, only node 0 had a checkpoint - none of the others did.
Do you think this caused the issue? If so, does Orbax have a way to work around this?
Another option is that I can try to modify an on-disk checkpoint using tensorstore, as in the documentation; there is no real reason we need to load the checkpoint onto GPUs to convert it from one format to another.
We just found the place in the documentation where orbax says that all nodes need to write to the same filesystem - that explains what went wrong for us.
Exactly, I think it should be the root cause.
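For future runs, a quick pre-flight check along these lines might catch this earlier (a sketch; the path and marker-file scheme are made up and not part of MaxText): have every JAX process write a marker into the intended output directory, then verify that process 0 sees all of them, which only happens when the directory is truly shared (GCS or a common NFS mount).

```python
import os
import jax
from jax.experimental import multihost_utils

OUTPUT_DIR = "/shared/ckpts/gpt3-conversion"  # must be the same shared path on every node

# Each process drops a marker file, then all processes synchronize.
os.makedirs(OUTPUT_DIR, exist_ok=True)
open(os.path.join(OUTPUT_DIR, f".marker_{jax.process_index()}"), "w").close()
multihost_utils.sync_global_devices("marker_written")

if jax.process_index() == 0:
    markers = [f for f in os.listdir(OUTPUT_DIR) if f.startswith(".marker_")]
    assert len(markers) == jax.process_count(), (
        f"Only {len(markers)}/{jax.process_count()} processes can write to {OUTPUT_DIR}; "
        "Orbax needs a filesystem shared by all nodes (e.g. GCS or NFS)."
    )
```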
Hi @gabeweisz, were you able to figure it out? Let me know if you have any further questions. Otherwise, we will close it.
Please go ahead and resolve the ticket
Thank you @gabeweisz !