
Adjusting "global_batch_size" and "micro_batch_size" has no impact on how long each training step takes when using HFAutoModel.

Open jiuqiant opened this issue 7 months ago • 2 comments

Describe the bug

The training step time remains constant for Gemma3 HFAutoModel and MockDataModule, regardless of the "global_batch_size" and "micro_batch_size" values set in MockDataModule.

Steps/Code to reproduce bug

The "gemma3_automodel_test.py":

import fiddle as fdl
from lightning.pytorch.loggers import TensorBoardLogger

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.utils.exp_manager import TimingCallback

HF_MODEL_NAME = "google/gemma-3-4b-it"

SEQ_LENGTH = 8192

def main():
    max_sequence_length = SEQ_LENGTH
    num_nodes = 1

    model = llm.HFAutoModelForCausalLM(
        model_name=HF_MODEL_NAME,
        model_accelerator=None,
        trust_remote_code=True,
        use_liger_kernel=True,
    )
    
    tokenizer = AutoTokenizer(HF_MODEL_NAME)
    data = llm.MockDataModule(
        num_train_samples=1000000,
        seq_length=max_sequence_length,
        global_batch_size=128, # tried 1, 8, 32, 64 ...
        micro_batch_size=1, # tried 1, 2, 4, 8 ...
    )

    strategy = nl.FSDP2Strategy(
        data_parallel_size=8 * num_nodes,
        tensor_parallel_size=1,
        checkpoint_io=model.make_checkpoint_io(adapter_only=True),
        offload_policy=None,
    )

    trainer = nl.Trainer(
        accelerator="gpu",
        devices=8,
        num_nodes=num_nodes,
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        max_epochs=None,
        max_steps=20,
        log_every_n_steps=1,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[TimingCallback()],
    )

    opt = fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=3e-4))

    # Set up checkpointing and a TensorBoard logger
    ckpt = nl.ModelCheckpoint(
        save_top_k=1,
        save_last=True,
        save_optim_on_train_end=False,
        filename="{val_loss:.2f}-{step}-{consumed_samples}",
    )
    tb = TensorBoardLogger(
        save_dir="tensorboard",
        name="",
    )
    logger = nl.NeMoLogger(
        explicit_log_dir="/logs",
        log_global_rank_0_only=True,
        update_logger_directory=True,
        ckpt=ckpt,
        tensorboard=tb,
    )

    resume = nl.AutoResume(
        resume_if_exists=True,
        resume_ignore_no_checkpoint=True,
    )

    llm.pretrain(
        model=model,
        data=data, 
        trainer=trainer,
        log=logger,
        optim=opt,
        resume=resume,
    )


if __name__ == "__main__":
    main()

cmd to run:

docker run --gpus all -it --rm -v /home/test_run:/workspace/test_run --shm-size=8g --ulimit memlock=-1 --ulimit stack=67108864 -e CUDA_DEVICE_MAX_CONNECTIONS=1 nvcr.io/nvidia/nemo:25.04 /bin/bash -c "HF_TOKEN=hf_<your id> python test_run/gemma3_automodel_test.py"

I consistently observe the same training step time in the logs, regardless of the "global_batch_size" and "micro_batch_size" settings. On a node with 8 H100s:

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
  | Name  | Type                               | Params | Mode 
---------------------------------------------------------------------
0 | model | FSDPGemma3ForConditionalGeneration | 4.3 B  | train
---------------------------------------------------------------------
4.3 B     Trainable params
0         Non-trainable params
4.3 B     Total params
17,200.318 Total estimated model params size (MB)
888       Modules in train mode
0         Modules in eval mode

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  6.32it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 10.44it/s]
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

Sanity Checking: |          | 0/? [00:00<?, ?it/s]
Sanity Checking:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:06<00:06,  0.16it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:08<00:00,  0.23it/s]
                                                                           

Training: |          | 0/? [00:00<?, ?it/s]
Training:   0%|          | 0/125000 [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/125000 [00:00<?, ?it/s] Current pytorch-triton version: 3.2.0+gitb2684bf3b.nvinternal, Required triton version: 3.2.0

Epoch 0:   0%|          | 1/125000 [00:05<198:49:05,  0.17it/s]
Epoch 0:   0%|          | 1/125000 [00:05<198:50:07,  0.17it/s, v_num=0, global_step=0.000, reduced_train_loss=14.00, tps=1.2e+4, lr=0.0003, train_step_timing in s=5.720]
Epoch 0:   0%|          | 2/125000 [00:07<138:43:01,  0.25it/s, v_num=0, global_step=0.000, reduced_train_loss=14.00, tps=1.2e+4, lr=0.0003, train_step_timing in s=5.720]
Epoch 0:   0%|          | 2/125000 [00:07<138:43:20,  0.25it/s, v_num=0, global_step=1.000, reduced_train_loss=14.90, tps=34084.0, lr=0.0003, train_step_timing in s=2.260]
Epoch 0:   0%|          | 3/125000 [00:10<118:47:45,  0.29it/s, v_num=0, global_step=1.000, reduced_train_loss=14.90, tps=34084.0, lr=0.0003, train_step_timing in s=2.260]
Epoch 0:   0%|          | 3/125000 [00:10<118:47:58,  0.29it/s, v_num=0, global_step=2.000, reduced_train_loss=14.90, tps=28905.0, lr=0.0003, train_step_timing in s=2.270]
Epoch 0:   0%|          | 4/125000 [00:12<108:50:50,  0.32it/s, v_num=0, global_step=2.000, reduced_train_loss=14.90, tps=28905.0, lr=0.0003, train_step_timing in s=2.270]
Epoch 0:   0%|          | 4/125000 [00:12<108:51:00,  0.32it/s, v_num=0, global_step=3.000, reduced_train_loss=14.10, tps=28760.0, lr=0.0003, train_step_timing in s=2.270]
Epoch 0:   0%|          | 5/125000 [00:14<102:53:18,  0.34it/s, v_num=0, global_step=3.000, reduced_train_loss=14.10, tps=28760.0, lr=0.0003, train_step_timing in s=2.270]
Epoch 0:   0%|          | 5/125000 [00:14<102:53:27,  0.34it/s, v_num=0, global_step=4.000, reduced_train_loss=17.70, tps=28734.0, lr=0.0003, train_step_timing in s=2.270]
Epoch 0:   0%|          | 6/125000 [00:17<99:02:51,  0.35it/s, v_num=0, global_step=4.000, reduced_train_loss=17.70, tps=28734.0, lr=0.0003, train_step_timing in s=2.270] 
Epoch 0:   0%|          | 6/125000 [00:17<99:02:58,  0.35it/s, v_num=0, global_step=5.000, reduced_train_loss=14.40, tps=28522.0, lr=0.0003, train_step_timing in s=2.300]
Epoch 0:   0%|          | 7/125000 [00:19<96:17:52,  0.36it/s, v_num=0, global_step=5.000, reduced_train_loss=14.40, tps=28522.0, lr=0.0003, train_step_timing in s=2.300]
Epoch 0:   0%|          | 7/125000 [00:19<96:17:58,  0.36it/s, v_num=0, global_step=6.000, reduced_train_loss=13.70, tps=28514.0, lr=0.0003, train_step_timing in s=2.300]
Epoch 0:   0%|          | 8/125000 [00:21<94:17:27,  0.37it/s, v_num=0, global_step=6.000, reduced_train_loss=13.70, tps=28514.0, lr=0.0003, train_step_timing in s=2.300]
Epoch 0:   0%|          | 8/125000 [00:21<94:17:33,  0.37it/s, v_num=0, global_step=7.000, reduced_train_loss=14.50, tps=28361.0, lr=0.0003, train_step_timing in s=2.310]
Epoch 0:   0%|          | 9/125000 [00:24<92:43:17,  0.37it/s, v_num=0, global_step=7.000, reduced_train_loss=14.50, tps=28361.0, lr=0.0003, train_step_timing in s=2.310]
Epoch 0:   0%|          | 9/125000 [00:24<92:43:21,  0.37it/s, v_num=0, global_step=8.000, reduced_train_loss=14.50, tps=28349.0, lr=0.0003, train_step_timing in s=2.310]
Epoch 0:   0%|          | 10/125000 [00:26<91:25:34,  0.38it/s, v_num=0, global_step=8.000, reduced_train_loss=14.50, tps=28349.0, lr=0.0003, train_step_timing in s=2.310]
Epoch 0:   0%|          | 10/125000 [00:26<91:25:38,  0.38it/s, v_num=0, global_step=9.000, reduced_train_loss=15.00, tps=28524.0, lr=0.0003, train_step_timing in s=2.290]
Epoch 0:   0%|          | 10/125000 [00:26<91:25:57,  0.38it/s, v_num=0, global_step=9.000, reduced_train_loss=15.00, tps=28524.0, lr=0.0003, train_step_timing in s=2.290]`Trainer.fit` stopped: `max_steps=10` reached.

Epoch 0:   0%|          | 10/125000 [00:26<91:26:06,  0.38it/s, v_num=0, global_step=9.000, reduced_train_loss=15.00, tps=28524.0, lr=0.0003, train_step_timing in s=2.290]Current pytorch-triton version: 3.2.0+gitb2684bf3b.nvinternal, Required triton version: 3.2.0

Expected behavior

"global_batch_size" and "micro_batch_size" should affect the training step time. For example, we expect longer training step time with larger "global_batch_size".

Environment overview (please complete the following information)

GCP docker env with nvcr.io/nvidia/nemo:25.04

jiuqiant · Jun 11 '25

Hi @jiuqiant ,

Please use the accumulate_grad_batches parameter passed to the Trainer instead.

The global_batch_size parameter in the datamodule has been deprecated; micro_batch_size can be used to specify the batch size of each rank, and the effective global batch size is then micro_batch_size × data_parallel_size × accumulate_grad_batches.
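
For concreteness, here is a minimal sketch of that change applied to the script above (accumulate_grad_batches=4 is just an illustrative value; everything else mirrors the original setup):

data = llm.MockDataModule(
    num_train_samples=1000000,
    seq_length=max_sequence_length,
    micro_batch_size=1,  # per-rank batch size; no global_batch_size here
)

trainer = nl.Trainer(
    accelerator="gpu",
    devices=8,
    num_nodes=num_nodes,
    strategy=strategy,
    plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    max_steps=20,
    log_every_n_steps=1,
    limit_val_batches=2,
    limit_test_batches=2,
    accumulate_grad_batches=4,  # illustrative; standard Lightning gradient accumulation
    callbacks=[TimingCallback()],
)

Increasing micro_batch_size increases the work done per training step directly, while accumulate_grad_batches controls how many of those steps are accumulated before each optimizer update.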

Please let me know if you still encounter the issue.

Thank you.

akoumpa · Jun 13 '25

@akoumpa thank you for the clarification on accumulate_grad_batches and global_batch_size. Regarding the "train_step_timing in s" value in the logs, what exactly does it measure? Does it refer to the time taken for a forward and backward pass? Is there a separate log entry for the optimizer's step time?

jiuqiant · Jun 13 '25

This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.

github-actions[bot] · Jul 13 '25

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot] · Jul 20 '25