converted mlperf gpt3 ckpt starts with a worse loss

Open gramesh-amd opened this issue 1 year ago • 26 comments

Hello, we converted the paxml checkpoint and resumed training with the following config:

base_config: "base.yml"
tokenizer_path: "/dockerx/vocab/c4_en_301_5Mexp2_spm.model"
dataset_type: "tfds"
dataset_path: "/ckpts/c4_mlperf_dataset"
dataset_name: "en:3.0.4"
eval_dataset_name: "en:3.0.5"
split: "train2"
tokenize_eval_data: False
eval_data_column: "ids"
add_bos: False
add_eos: False
eval_split: "validation_tokenized_5662seqs"
eval_interval: 10  # the specific number of train step between eval_step
target_eval_loss: 2.69  # early stop once reaching target eval_loss

enable_checkpointing: True
save_interval_steps: 5

# Args coming from the NVIDIA spreadsheet http://shortn/_W9CzVbtQde and
# third_party/py/maxtext/configs/a3/llama_2_7b.
hardware: "gpu"
steps: 10
model_name: "gpt3-175b" # this model config is unchanged
attention: "cudnn_flash_te"

gradient_accumulation_steps: 1

dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1
dcn_pipeline_parallelism: 1
dcn_tensor_parallelism: 1
dcn_sequence_parallelism: 1
ici_fsdp_parallelism: 8
ici_data_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_pipeline_parallelism: 1
per_device_batch_size: 5
max_target_length: 2048

remat_policy: "full"
use_iota_embed: True
scan_layers: False
async_checkpointing: False
logits_dot_in_fp32: False
megablox: False

dtype: "bfloat16"
quantization: ""
quantize_kvcache: False
kv_quant_axis: "heads_and_dkv"
kv_quant_dtype: "int8"
weight_dtype: bfloat16
checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint

skip_first_n_steps_for_profiler: 3

mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
                      ['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
                       # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
                       # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
                       # The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
                      ['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']],
                      ['activation_heads', ['tensor','sequence']],
                      ['activation_kv_heads', ['tensor','sequence']],
                      ['activation_length', 'sequence'],
                      ['activation_embed', 'tensor'],
                      ['activation_mlp', 'tensor'],
                      ['activation_kv', 'tensor'],
                      ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose',]],
                      ['activation_kv_head_dim', 'tensor'],
                      ['activation_vocab', ['tensor', 'sequence']],
                      ['activation_vocab', 'tensor'],
                      ['activation_vocab', 'sequence'],
                      ['activation_stage','stage'],
                      ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
                      ['vocab', ['tensor', 'autoregressive']],
                      ['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
                      ['embed', ['fsdp', 'sequence']],
                      ['norm', 'fsdp'],
                      ['heads', ['tensor', 'autoregressive']],
                      ['layers', 'stage'],
                      ['kv', []],
                      ['kv_heads', ['tensor', 'autoregressive']],
                      ['kv_head_dim', []],
                      ['cache_batch', []],
                      ['cache_heads', ['autoregressive', 'tensor']],
                      ['cache_kv', []],
                      ['cache_sequence', []],
                    ]

# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]

The tokenizer and data splits (3.0.4, 3.0.5) were downloaded from the mlperf2 bucket. I have also tried using the c4_mlperf dataset_type like this:

base_config: "base.yml"
tokenizer_path: "/dockerx/vocab/c4_en_301_5Mexp2_spm.model"
dataset_type: "c4_mlperf"
dataset_path: "/ckpts/c4_mlperf_dataset"
dataset_name: "en:3.0.4"
eval_dataset_name: "en:3.0.5"
split: "train2"
eval_split: "validation_tokenized_5662seqs"

Training was launched with:

python maxtext/MaxText/train.py /dockerx/maxtext/MaxText/configs/gpt3_175b_gpu.yml base_output_directory=/ckpts/paxml/gpt3-conversion run_name=gpt3-conversion steps=4010 scan_layers=true

^ scan_layers is set to true to match how we converted the checkpoint.

completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295
To see full metrics 'tensorboard --logdir=/ckpts/paxml/gpt3-conversion/gpt3-conversion/tensorboard/'
completed step: 4001, seconds: 12.945, TFLOP/s/device: 170.297, Tokens/s/device: 158.213, total_weights: 65504, loss: 7.687, perplexity: 2179.917
completed step: 4002, seconds: 11.886, TFLOP/s/device: 185.471, Tokens/s/device: 172.310, total_weights: 65504, loss: 7.739, perplexity: 2297.215
completed step: 4003, seconds: 11.885, TFLOP/s/device: 185.479, Tokens/s/device: 172.318, total_weights: 65504, loss: 7.597, perplexity: 1992.680
completed step: 4004, seconds: 11.931, TFLOP/s/device: 184.759, Tokens/s/device: 171.649, total_weights: 65504, loss: 7.680, perplexity: 2165.097
completed step: 4005, seconds: 11.913, TFLOP/s/device: 185.043, Tokens/s/device: 171.912, total_weights: 65504, loss: 7.663, perplexity: 2128.778
completed step: 4006, seconds: 11.945, TFLOP/s/device: 184.546, Tokens/s/device: 171.451, total_weights: 65504, loss: 7.582, perplexity: 1963.248
completed step: 4007, seconds: 11.913, TFLOP/s/device: 185.048, Tokens/s/device: 171.918, total_weights: 65504, loss: 7.648, perplexity: 2096.574
completed step: 4008, seconds: 12.013, TFLOP/s/device: 183.498, Tokens/s/device: 170.478, total_weights: 65504, loss: 7.524, perplexity: 1851.645
completed step: 4009, seconds: 11.920, TFLOP/s/device: 184.929, Tokens/s/device: 171.807, total_weights: 65504, loss: 7.618, perplexity: 2034.629

^ Training starts with a very high loss (about 7.6); we expected something closer to 2.77.
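For reference, the logged perplexity appears to be exp(loss) (every line above matches to rounding), so the expected loss of 2.77 corresponds to a perplexity of about 16, versus roughly 2,000 here. A minimal sketch for pulling these numbers out of the training log; the regex assumes exactly the `completed step: ...` format shown above, and `parse_step_line` is a hypothetical helper, not part of MaxText:

```python
import math
import re

# Matches the step lines shown above (format assumed from this output only).
STEP_RE = re.compile(r"completed step: (\d+),.*?\bloss: ([\d.]+), perplexity: ([\d.]+)")

LOG = """\
completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295
completed step: 4001, seconds: 12.945, TFLOP/s/device: 170.297, Tokens/s/device: 158.213, total_weights: 65504, loss: 7.687, perplexity: 2179.917
completed step: 4008, seconds: 12.013, TFLOP/s/device: 183.498, Tokens/s/device: 170.478, total_weights: 65504, loss: 7.524, perplexity: 1851.645
"""


def parse_step_line(line):
    """Return {'step', 'loss', 'perplexity'} for a step line, or None for other lines."""
    match = STEP_RE.search(line)
    if match is None:
        return None
    return {
        "step": int(match.group(1)),
        "loss": float(match.group(2)),
        "perplexity": float(match.group(3)),
    }


records = [r for r in map(parse_step_line, LOG.splitlines()) if r is not None]
mean_loss = sum(r["loss"] for r in records) / len(records)
for r in records:
    # Perplexity should equal exp(loss) up to the 3-decimal rounding in the log.
    print(r["step"], r["loss"], round(math.exp(r["loss"]), 1), r["perplexity"])
print(f"mean loss {mean_loss:.3f}; expected ~2.77 -> perplexity {math.exp(2.77):.1f}")
```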

We have confirmed from the logs that training loads the right checkpoint, the correct data splits, and the right tokenizer.

gramesh-amd, Sep 13 '24

@ZhiyuLi-goog thanks again for your help with other issues. Do you see any problems with the config, or do you know why the loss is so much higher?

gramesh-amd, Sep 13 '24

I have never tried this on GPU. To narrow down the root cause, could you try with standard dot-product attention?

attention: "dot_product"

ZhiyuLi-goog, Sep 13 '24

With attention: "dot_product":

completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295
To see full metrics 'tensorboard --logdir=/ckpts/paxml/gpt3-conversion/gpt3-conversion/tensorboard/'
completed step: 4001, seconds: 39.677, TFLOP/s/device: 277.795, Tokens/s/device: 258.083, total_weights: 327520, loss: 7.638, perplexity: 2076.376
completed step: 4002, seconds: 39.883, TFLOP/s/device: 276.359, Tokens/s/device: 256.749, total_weights: 327520, loss: 7.646, perplexity: 2092.290

I get a similar loss as before.

gramesh-amd, Sep 13 '24

Oh, could you try something like

python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b

instead of modifying base.yml? The exact model setup is in gpt3-175b.yml, which sets a few more flags for gpt3-175b:

# these flags might be relevant to output results
logits_via_embedding: True
normalize_embedding_logits: False
logits_dot_in_fp32: False
normalization_layer_epsilon: 1.e-05
use_iota_embed: True
opt_type: "adam_pax"

I think logits_via_embedding: True is the most important one.
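As I understand it, `logits_via_embedding: True` makes the output layer reuse the token embedding table (weight tying) instead of a separate output projection with its own weights. The toy numpy sketch below uses random data and hypothetical function names; it only illustrates why the flag is worth checking: logits built from the trained embedding table give a low loss, while an output projection that was never trained gives a loss around ln(vocab) or higher.

```python
import numpy as np


def tied_logits(hidden, embedding):
    """Weight tying: project onto the [vocab, embed] token embedding table."""
    return hidden @ embedding.T


def untied_logits(hidden, output_kernel):
    """Separate [embed, vocab] output projection with its own weights."""
    return hidden @ output_kernel


def mean_cross_entropy(logits, targets):
    """Average next-token cross-entropy (natural log), computed stably."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())


rng = np.random.default_rng(0)
vocab, embed, tokens = 1000, 64, 256
embedding = rng.normal(scale=embed ** -0.5, size=(vocab, embed))
targets = rng.integers(0, vocab, size=tokens)
# Toy "trained" hidden states that point toward the target token's embedding.
hidden = 10.0 * embedding[targets]
# An output projection that was never trained (random weights).
random_kernel = rng.normal(scale=embed ** -0.5, size=(embed, vocab))

loss_tied = mean_cross_entropy(tied_logits(hidden, embedding), targets)
loss_untied = mean_cross_entropy(untied_logits(hidden, random_kernel), targets)
print(f"tied: {loss_tied:.2f}  untied/random: {loss_untied:.2f}  ln(vocab): {np.log(vocab):.2f}")
```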

ZhiyuLi-goog, Sep 13 '24

I tested these out, first by running

python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b

and then by adding the other relevant flags you posted one at a time. Every run starts with the same bad loss (7.6x). So it is not flash attention, not the tokenizer (the validation set is pretokenized and the eval loss is also bad), and not the config args (I tried the flags you suggested).

It's probably something to do with the model weights.

gramesh-amd, Sep 13 '24

I can take a look at the full logs if you have them. They should contain the final effective config.

ZhiyuLi-goog, Sep 13 '24

maxtext_gpt3_logs.txt

Thanks. Here are the logs.

gramesh-amd, Sep 17 '24

Checked the log. All updated parameters matched and I didn't find anything suspicious.

ZhiyuLi-goog, Sep 17 '24

Thanks for checking. Yeah, it's strange that it starts with a bad loss. I also tested the tokenizer, and it seems fine too.

gramesh-amd, Sep 17 '24

The only thing I found that looks odd is:

+ Config param weight_dtype: float32
- Config param weight_dtype: bfloat16

Could you try weight_dtype: float32 instead of bfloat16? Activations are computed in bfloat16, while all parameters and optimizer state should be kept in float32 for better convergence.

However, I would not expect this to cause such a big gap.
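To see why float32 parameter storage helps convergence, here is a toy numpy sketch (not MaxText code). NumPy has no native bfloat16, so the hypothetical `round_to_bf16` helper emulates it by rounding float32 values to their upper 16 bits (round-to-nearest-even, NaN/Inf handling omitted). Bfloat16 has only 8 significant bits, so near 1.0 the gap between representable values is 2^-7 (about 0.0078); a small update such as 1e-3 is rounded away every step if the weights themselves are stored in bfloat16.

```python
import numpy as np


def round_to_bf16(x):
    """Round float32 values to the nearest bfloat16 value, returned as float32.

    bfloat16 keeps the top 16 bits of a float32 (sign, 8 exponent bits, 7 mantissa
    bits). Adding 0x7FFF plus the lowest kept bit before truncating gives
    round-to-nearest-even. NaN/Inf handling is omitted for brevity.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    bias = ((bits >> np.uint32(16)) & np.uint32(1)) + np.uint32(0x7FFF)
    return ((bits + bias) & np.uint32(0xFFFF0000)).view(np.float32)


def apply_updates(num_steps, update, store_in_bf16):
    """Repeatedly add a small update to weights initialised at 1.0."""
    weights = np.ones(4, dtype=np.float32)
    for _ in range(num_steps):
        weights = weights + np.float32(update)
        if store_in_bf16:
            weights = round_to_bf16(weights)  # parameter storage in bfloat16
    return weights


w_fp32 = apply_updates(100, 1e-3, store_in_bf16=False)
w_bf16 = apply_updates(100, 1e-3, store_in_bf16=True)
print("float32 weights:", w_fp32[0])   # close to 1.1
print("bfloat16 weights:", w_bf16[0])  # stays at 1.0: every update is rounded away
```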

ZhiyuLi-goog, Sep 17 '24

I tried weight_dtype: float32 as well. Same problem.

I'm wondering if we could send you our converted checkpoint so you can load it and verify whether it's a checkpoint problem?

gramesh-amd, Sep 17 '24

I can give it a try on the TPU side.

By the way, would it be useful to print the mean of each parameter after conversion?
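A minimal sketch of such a per-parameter summary, assuming the converted and reference parameters have each been loaded into host memory as a nested dict of arrays (the loading step, the parameter names, and the helpers `iter_leaves`, `summarize_params`, and `print_comparison` are all hypothetical). Reporting the fraction of exact zeros alongside the mean and standard deviation makes a partially written checkpoint easy to spot.

```python
import numpy as np


def iter_leaves(tree, prefix=""):
    """Yield ('path/to/param', array) pairs from a nested dict of arrays."""
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            yield from iter_leaves(value, path)
        else:
            yield path, np.asarray(value)


def summarize_params(tree):
    """Return {path: (mean, std, fraction_of_exact_zeros)} for every parameter."""
    stats = {}
    for path, arr in iter_leaves(tree):
        arr = arr.astype(np.float64)
        stats[path] = (float(arr.mean()), float(arr.std()), float(np.mean(arr == 0)))
    return stats


def print_comparison(reference, converted):
    """Print the stats side by side for parameters present in both trees."""
    ref_stats, conv_stats = summarize_params(reference), summarize_params(converted)
    for path in sorted(ref_stats.keys() & conv_stats.keys()):
        (rm, rs, rz), (cm, cs, cz) = ref_stats[path], conv_stats[path]
        print(f"{path}: mean {rm:+.4e} vs {cm:+.4e}, std {rs:.4e} vs {cs:.4e}, "
              f"zeros {rz:.2%} vs {cz:.2%}")


# Toy example with made-up parameter names: the "converted" copy lost 3/4 of each row.
rng = np.random.default_rng(0)
ref = {"decoder": {"mlp": {"kernel": rng.normal(size=(4, 8))}}}
conv_kernel = ref["decoder"]["mlp"]["kernel"].copy()
conv_kernel[:, 2:] = 0.0
print_comparison(ref, {"decoder": {"mlp": {"kernel": conv_kernel}}})
```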

ZhiyuLi-goog, Sep 17 '24

I'm not sure it will be useful. We also loaded the Pax checkpoint directly in paxml, and it starts at the right loss. So at this point, we suspect something is going wrong during conversion.

gramesh-amd, Sep 20 '24

It would be easiest if you share your converted checkpoint; I can compare it directly against ours. If you have any output logs from the conversion script, I can take a look at those as well.

We haven't tried the conversion on GPU, so I guess something might behave differently there.

ZhiyuLi-goog, Sep 20 '24

Great, we will share the converted checkpoint and the conversion logs. Do you have a gcloud bucket I could push them to, or do you recommend some other way?

gramesh-amd, Sep 20 '24

It would be great if you could share them with us via an open gcloud bucket. By the way, which conversion script are you using: the one in the MLPerf 4.0 submission or the one in the maxtext main branch?

ZhiyuLi-goog, Sep 20 '24

OK, let me do that. We tried both versions, and we get the same problem with each.

gramesh-amd, Sep 20 '24

Gotcha, thank you for the info!

ZhiyuLi-goog, Sep 20 '24

We have created the bucket and will share access with you soon (I got your Google email from one of your commits).

gramesh-amd, Sep 30 '24

Hello again, you should finally have access to the bucket containing all the checkpoints.

We have shared three checkpoints:

  1. gpt3-conversion-forked/: created with the MLPerf fork
  2. gpt3-conversion-noscan-cache/
  3. gpt3-conversion/

The second and third were both created with the latest branch: the second with scan_layers=false, the third with scan_layers=true.
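For context on why scan_layers has to match between conversion and training: with scanning, the weights of all decoder layers are stored as single arrays with an extra layer dimension, whereas without scanning each layer has its own subtree, so the two checkpoints have different layouts. A toy numpy sketch of converting between the two (the `layers_{i}` key names are hypothetical, and the axis holding the layer dimension is left as a parameter because it depends on the framework's configuration):

```python
import numpy as np


def stack_layers(per_layer, num_layers, axis=0):
    """Unscanned -> scanned: {'layers_0': {...}, ...} -> {name: stacked array}."""
    names = per_layer["layers_0"].keys()
    return {
        name: np.stack([per_layer[f"layers_{i}"][name] for i in range(num_layers)], axis=axis)
        for name in names
    }


def unstack_layers(stacked, axis=0):
    """Scanned -> unscanned: split the layer dimension back into per-layer subtrees."""
    num_layers = next(iter(stacked.values())).shape[axis]
    return {
        f"layers_{i}": {name: np.take(arr, i, axis=axis) for name, arr in stacked.items()}
        for i in range(num_layers)
    }


rng = np.random.default_rng(0)
unscanned = {f"layers_{i}": {"kernel": rng.normal(size=(3, 5))} for i in range(4)}
scanned = stack_layers(unscanned, num_layers=4, axis=1)
print(scanned["kernel"].shape)  # (3, 4, 5): layer dimension inserted at axis 1
restored = unstack_layers(scanned, axis=1)
```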

Let us know whether you are able to access them and if you have any questions. Thanks!

gramesh-amd, Oct 07 '24

Thank you @gramesh-amd

I will test with your ckpt.

ZhiyuLi-goog, Oct 07 '24

I wrote a script to look at the checkpoint that we generated and compare it to the original data. What I found (at least on an initial look) is that each row has only the first 1/4 of its data (starting at element 0) populated, and the rest is 0. This is interesting because we used 4 nodes with 8 GPUs each, with intra- and inter-node FSDP. It makes me think that we have data from the first node and zeros from the other three.
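One way to confirm the pattern numerically, assuming both the converted and the original tensor have been loaded as numpy arrays of the same shape (the helper name `chunk_report` is hypothetical): split the last axis into as many equal chunks as there were nodes and check, per chunk, whether it contains any non-zero value and whether it matches the original.

```python
import numpy as np


def chunk_report(converted, reference, num_chunks):
    """For each of `num_chunks` equal slices of the last axis, report
    (has_nonzero, matches_reference) for the converted tensor."""
    converted = np.asarray(converted)
    reference = np.asarray(reference)
    if converted.shape != reference.shape:
        raise ValueError(f"shape mismatch: {converted.shape} vs {reference.shape}")
    conv_chunks = np.array_split(converted, num_chunks, axis=-1)
    ref_chunks = np.array_split(reference, num_chunks, axis=-1)
    return [
        (bool(np.any(c != 0)), bool(np.allclose(c, r)))
        for c, r in zip(conv_chunks, ref_chunks)
    ]


# Toy reproduction of the observed pattern: only the first quarter of each row survives.
rng = np.random.default_rng(0)
original = rng.normal(size=(6, 32))
converted = np.zeros_like(original)
converted[:, :8] = original[:, :8]
for i, (nonzero, matches) in enumerate(chunk_report(converted, original, num_chunks=4)):
    print(f"chunk {i}: nonzero={nonzero}, matches_original={matches}")
```

If the report shows the first chunk intact and the other three all zero for every sharded parameter, that is consistent with only node 0's shards having been written.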

Any suggestions for how we can confirm that this is what happened and debug it?

gabeweisz, Oct 08 '24

@gabeweisz Great finding and thank you for looking into it!

I am wondering how you ran the script. I would expect:

  1. the script to run on each device; this is a no-brainer on TPU, and I am wondering if some tweak is needed on GPU
  2. the output directory to be accessible from each device
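A cheap pre-flight check for point 2, sketched with hypothetical helper names: every process writes a marker file into the output directory, and after all processes have synchronized (in a real multi-host JAX job, for example with `jax.experimental.multihost_utils.sync_global_devices`), each one checks that it can see every marker. The sketch simulates four hosts inside one Python process, once with a shared directory and once with a separate local directory per node.

```python
import os
import tempfile


def write_marker(output_dir, process_index):
    """Each process drops a marker file into the directory it believes is shared."""
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, f"process_{process_index}.ok"), "w") as f:
        f.write("ok\n")


def missing_markers(output_dir, process_count):
    """Process indices whose marker is not visible from this process's view of output_dir."""
    return [
        i for i in range(process_count)
        if not os.path.exists(os.path.join(output_dir, f"process_{i}.ok"))
    ]


NUM_PROCESSES = 4
with tempfile.TemporaryDirectory() as root:
    # Case 1: every process writes to the same (shared) directory.
    shared = os.path.join(root, "shared")
    for p in range(NUM_PROCESSES):
        write_marker(shared, p)
    shared_missing = missing_markers(shared, NUM_PROCESSES)

    # Case 2: each process writes to its own node-local directory.
    local_dirs = [os.path.join(root, f"node{p}") for p in range(NUM_PROCESSES)]
    for p in range(NUM_PROCESSES):
        write_marker(local_dirs[p], p)
    local_missing = missing_markers(local_dirs[0], NUM_PROCESSES)  # as seen from node 0

print("shared dir, missing:", shared_missing)
print("node-local dirs, missing on node 0:", local_missing)
```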

Key idea:

# Each process places its slices of the weights on its local devices (FSDP), so each device keeps only its own shard.
result = jax.make_array_from_single_device_arrays(
    shape,
    sharding,
    [jax.device_put(np.array(arr[index]), d) for d, index in sharding.addressable_devices_indices_map(shape).items()],
)
...
...

# Save in a distributed way to an output directory (e.g. a GCS bucket) that all devices can access.
if save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
  max_logging.log(f"saved a checkpoint at step {converted_state.step}")
# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(converted_state.step):
  checkpoint_manager.wait_until_finished()
  sys.exit()

Example:

I tried the script yesterday, and it worked for me.

# checkpoint loading from a fixed folder
RUN_NAME=ckpt
BASE_OUTPUT_DIR=gs://path/to/output

python MaxText/convert_gpt3_ckpt_from_paxml.py \
  --paxml-ckpt-path=gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000 \
  --maxtext-model-name=gpt3-175b \
  --run-name=$RUN_NAME \
  --base-output-directory=$BASE_OUTPUT_DIR

Note that the output directory is a GCS bucket, which is accessible from all devices.

ZhiyuLi-goog, Oct 08 '24

We ran the script in a way very similar to how you ran it - my colleague Gowtham has shared what we did earlier.

When we ran this, we didn't have a shared NFS big enough for all the nodes and did not have access to a GCS bucket - each node was writing to its own local directory.

I did check, and once the script was finished, only node 0 had a checkpoint - none of the others did.

Do you think this caused the issue? If so, does Orbax have a way to work around this?

Another option: I could modify the on-disk checkpoint directly using TensorStore, as in the documentation. There is no real reason we need to load the checkpoint onto GPUs just to convert it from one format to another.

gabeweisz, Oct 08 '24

We just found the place in the documentation where orbax says that all nodes need to write to the same filesystem - that explains what went wrong for us.
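This also accounts for the exact pattern seen earlier. A toy simulation, not Orbax code: a weight is split column-wise across 4 nodes with 8 devices each, each node saves its devices' shards to the directory it was given, and the reader fills any shard it cannot find with zeros. The zero fill is an assumption chosen because it matches what was observed (chunked array stores such as zarr return a fill value for chunks that were never written); the column-wise split and node-major device order are assumptions too, and `save_shards` / `load_with_fill` are hypothetical helpers.

```python
import os
import tempfile

import numpy as np

NUM_NODES, GPUS_PER_NODE = 4, 8
NUM_DEVICES = NUM_NODES * GPUS_PER_NODE


def save_shards(global_array, dir_for_node):
    """Each device's column shard is saved by the node that owns the device."""
    cols = global_array.shape[1] // NUM_DEVICES
    for device in range(NUM_DEVICES):
        node = device // GPUS_PER_NODE
        path = os.path.join(dir_for_node(node), f"shard_{device}.npy")
        np.save(path, global_array[:, device * cols:(device + 1) * cols])


def load_with_fill(directory, shape):
    """Reassemble from one directory; shards that are not found are read as zeros."""
    out = np.zeros(shape, dtype=np.float32)
    cols = shape[1] // NUM_DEVICES
    for device in range(NUM_DEVICES):
        path = os.path.join(directory, f"shard_{device}.npy")
        if os.path.exists(path):
            out[:, device * cols:(device + 1) * cols] = np.load(path)
    return out


rng = np.random.default_rng(0)
weights = rng.normal(size=(4, 64)).astype(np.float32)

with tempfile.TemporaryDirectory() as root:
    shared_dir = os.path.join(root, "shared")
    os.makedirs(shared_dir)
    save_shards(weights, lambda node: shared_dir)
    from_shared = load_with_fill(shared_dir, weights.shape)

    local_dirs = [os.path.join(root, f"node{n}") for n in range(NUM_NODES)]
    for path in local_dirs:
        os.makedirs(path)
    save_shards(weights, lambda node: local_dirs[node])
    from_node0 = load_with_fill(local_dirs[0], weights.shape)

populated = np.any(from_node0 != 0, axis=0)
print("shared dir intact:", np.array_equal(from_shared, weights))
print("node-0 view, populated columns:", int(populated.sum()), "of", weights.shape[1])
```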

gabeweisz, Oct 08 '24

Exactly, I think it should be the root cause.

ZhiyuLi-goog, Oct 08 '24

Hi @gabeweisz, were you able to figure it out? Let me know if you have any further questions. Otherwise, we will close it.

ZhiyuLi-goog, Dec 27 '24

Please go ahead and resolve the ticket.

gabeweisz, Dec 30 '24

Thank you @gabeweisz !

ZhiyuLi-goog, Dec 30 '24