
Fix bad synthetic dataloader with per-device batch size < 1.

Open • wang2yn84 opened this issue 8 months ago • 0 comments

Description

Fix BadSyntheticDataIterator for Grain. BadSyntheticDataIterator is missing its local iterator, so the workload errors out when a Grain dataset is used together with per_device_batch_size < 1.
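A minimal sketch of the failure mode, under stated assumptions: it assumes that code downstream of the data pipeline (for example, Grain iterator checkpointing) reads a `local_iterator` attribute from whatever iterator it is handed, and that the bad synthetic iterator fills batches with -1 token ids for hosts that do not load real data. `PlaceholderIterator` and `save_iterator_state` are hypothetical names for illustration, not MaxText's actual classes or functions.

```python
from typing import Dict, Iterator, List

# A batch maps feature names to a list of token-id rows (hypothetical layout).
Batch = Dict[str, List[List[int]]]


class PlaceholderIterator:
    """Hypothetical stand-in for a 'bad' synthetic data iterator.

    Yields batches whose token ids are all -1, so hosts that load no real
    data still produce correctly shaped inputs. It also exposes
    `local_iterator`, the attribute this sketch assumes callers rely on.
    """

    def __init__(self, batch_size: int, seq_len: int):
        self.batch_size = batch_size
        self.seq_len = seq_len
        # Assumed to be read by downstream code; omitting this line makes
        # save_iterator_state below raise AttributeError.
        self.local_iterator: Iterator[Batch] = self._generate()

    def _generate(self) -> Iterator[Batch]:
        # Endless stream of identical placeholder batches.
        while True:
            rows = [[-1] * self.seq_len for _ in range(self.batch_size)]
            yield {"inputs": rows, "targets": [list(r) for r in rows]}

    def __iter__(self) -> "PlaceholderIterator":
        return self

    def __next__(self) -> Batch:
        return next(self.local_iterator)


def save_iterator_state(data_iterator) -> str:
    """Hypothetical consumer that needs the wrapped local iterator."""
    local = data_iterator.local_iterator  # AttributeError if missing
    return type(local).__name__


it = PlaceholderIterator(batch_size=2, seq_len=4)
print(next(it)["inputs"])       # [[-1, -1, -1, -1], [-1, -1, -1, -1]]
print(save_iterator_state(it))  # generator
```

Dropping the `self.local_iterator = ...` assignment makes `save_iterator_state` raise `AttributeError`, which mirrors the kind of error described above; exactly where the real workload fails is not stated in this PR.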

Tests

Manually ran the following workload:

    python -m MaxText.train MaxText/configs/base.yml \
      skip_jax_distributed_system=True run_name=lance_test \
      attention=dot_product dataset_type=grain \
      tokenizer_path=assets/tokenizer.llama2 hardware=gpu \
      logits_dot_in_fp32=false enable_goodput_recording=false monitor_goodput=false \
      remat_policy=full weight_dtype=bfloat16 save_config_to_gcs=false scan_layers=false \
      per_device_batch_size=0.25 \
      dcn_fsdp_parallelism=-1 dcn_data_parallelism=1 \
      ici_fsdp_parallelism=1 ici_tensor_parallelism=8 \
      packing=false enable_checkpoint_cloud_logger=true \
      dataset_path=/scratch/lancewang/dataset_pvc/ \
      grain_train_files=/scratch/lancewang/dataset_pvc/array-record/c4/en/3.0.1/c4-train.array_record* \
      grain_worker_count=1 \
      enable_checkpointing=false async_checkpointing=true checkpoint_period=10 \
      save_config_to_gcs=false
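For reference, the arithmetic behind `per_device_batch_size < 1`: the global batch is the per-device batch size multiplied by the number of devices, so a fractional value still yields whole sequences per step as long as enough devices share them. The sketch below assumes the test command runs on one host with 8 GPUs (inferred from `ici_tensor_parallelism=8` and `ici_fsdp_parallelism=1`; the PR does not state the device count), and `global_batch_size` is a hypothetical helper, not a MaxText function.

```python
from fractions import Fraction


def global_batch_size(per_device_batch_size: float, num_devices: int) -> int:
    """Hypothetical helper: sequences per training step across all devices.

    Assumes global batch = per_device_batch_size * num_devices. Rejecting
    non-whole or sub-1 results is a choice made for this sketch.
    """
    # Go through str() so decimal inputs like 0.3 stay exact (3/10).
    total = Fraction(str(per_device_batch_size)) * num_devices
    if total.denominator != 1 or total < 1:
        raise ValueError(
            f"per_device_batch_size={per_device_batch_size} on {num_devices} "
            "devices does not give a whole, positive batch"
        )
    return int(total)


# per_device_batch_size from the test command; 8 devices is an assumption.
print(global_batch_size(0.25, 8))  # 2
```

With the assumed 8 devices, per_device_batch_size=0.25 trains on 2 sequences per step; a value such as 0.3 would give 2.4 and is rejected by this sketch.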

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [x] I have performed a self-review of my code.
  • [x] I have added necessary comments in my code, particularly in hard-to-understand areas.
  • [x] I have run end-to-end tests and provided workload links above if applicable.
  • [x] I have made or will make corresponding changes to the doc if needed.

wang2yn84, May 09 '25 00:05