Adjusting "global_batch_size" and "micro_batch_size" has no impact on how long each training step takes when using HFAutoModel.
Describe the bug
When training Gemma 3 with HFAutoModelForCausalLM and MockDataModule, the training step time stays constant regardless of the "global_batch_size" and "micro_batch_size" values set in MockDataModule.
Steps/Code to reproduce bug
The script "gemma3_automodel_test.py":
import os
import sys
import argparse
from nemo.collections import llm
from nemo import lightning as nl
from dataclasses import dataclass, asdict
from nemo.collections.llm.gpt.model.llama import *
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.llm.gpt.data import PreTrainingDataModule
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.hf_dataset import HFMockDataModule
from megatron.core.optimizer import OptimizerConfig
from nemo.collections.common.metrics.perf_metrics import FLOPsMeasurementCallback
from nemo.utils.callbacks import NeMoModelCheckpoint
from nemo.utils.exp_manager import TimingCallback
import lightning.pytorch as pl
from nemo.lightning.pytorch.accelerate.transformer_engine import TEConfig
from nemo.lightning.pytorch.callbacks import MemoryProfileCallback
from nemo.lightning.pytorch.callbacks import NsysCallback
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch import seed_everything
import torch
import fiddle as fdl
HF_MODEL_NAME = "google/gemma-3-4b-it"
SEQ_LENGTH = 8192
def main():
    max_sequence_length = 8192
    num_nodes = 1
    model = llm.HFAutoModelForCausalLM(
        model_name=HF_MODEL_NAME,
        model_accelerator=None,
        trust_remote_code=True,
        use_liger_kernel=True,
    )
    tokenizer = AutoTokenizer(HF_MODEL_NAME)
    data = llm.MockDataModule(
        num_train_samples=1000000,
        seq_length=max_sequence_length,
        global_batch_size=128,  # tried 1, 8, 32, 64 ...
        micro_batch_size=1,  # tried 1, 2, 4, 8 ...
    )
    strategy = nl.FSDP2Strategy(
        data_parallel_size=8 * num_nodes,
        tensor_parallel_size=1,
        checkpoint_io=model.make_checkpoint_io(adapter_only=True),
        offload_policy=None,
    )
    trainer = nl.Trainer(
        accelerator="gpu",
        devices=8,
        num_nodes=num_nodes,
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        max_epochs=None,
        max_steps=20,
        log_every_n_steps=1,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[TimingCallback()],
    )
    opt = fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=3e-4))
    # Set up checkpointing and TensorBoard for the logger
    ckpt = nl.ModelCheckpoint(
        save_top_k=1,
        save_last=True,
        save_optim_on_train_end=False,
        filename="{val_loss:.2f}-{step}-{consumed_samples}",
    )
    tb = TensorBoardLogger(
        save_dir="tensorboard",
        name="",
    )
    logger = nl.NeMoLogger(
        explicit_log_dir="/logs",
        log_global_rank_0_only=True,
        update_logger_directory=True,
        ckpt=ckpt,
        tensorboard=tb,
    )
    resume = nl.AutoResume(
        resume_if_exists=True,
        resume_ignore_no_checkpoint=True,
    )
    llm.pretrain(
        model=model,
        data=data,
        trainer=trainer,
        log=logger,
        optim=opt,
        resume=resume,
    )


if __name__ == "__main__":
    main()
Command to run:
docker run --gpus all -it --rm -v /home/test_run:/workspace/test_run --shm-size=8g --ulimit memlock=-1 --ulimit stack=67108864 -e CUDA_DEVICE_MAX_CONNECTIONS=1 nvcr.io/nvidia/nemo:25.04 /bin/bash -c "HF_TOKEN=hf_<your id> python test_run/gemma3_automodel_test.py"
I consistently see the same training step time (about 2.3 s) in the logs, regardless of the "global_batch_size" and "micro_batch_size" settings. Example log from a node with 8 H100s:
Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]
| Name | Type | Params | Mode
---------------------------------------------------------------------
0 | model | FSDPGemma3ForConditionalGeneration | 4.3 B | train
---------------------------------------------------------------------
4.3 B Trainable params
0 Non-trainable params
4.3 B Total params
17,200.318 Total estimated model params size (MB)
888 Modules in train mode
0 Modules in eval mode
Loading checkpoint shards: 50%|█████ | 1/2 [00:00<00:00, 6.32it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 10.44it/s]
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Sanity Checking: | | 0/? [00:00<?, ?it/s]
Sanity Checking: 0%| | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 50%|█████ | 1/2 [00:06<00:06, 0.16it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:08<00:00, 0.23it/s]
Training: | | 0/? [00:00<?, ?it/s]
Training: 0%| | 0/125000 [00:00<?, ?it/s]
Epoch 0: 0%| | 0/125000 [00:00<?, ?it/s] Current pytorch-triton version: 3.2.0+gitb2684bf3b.nvinternal, Required triton version: 3.2.0
Epoch 0: 0%| | 1/125000 [00:05<198:49:05, 0.17it/s]
Epoch 0: 0%| | 1/125000 [00:05<198:50:07, 0.17it/s, v_num=0, global_step=0.000, reduced_train_loss=14.00, tps=1.2e+4, lr=0.0003, train_step_timing in s=5.720]
Epoch 0: 0%| | 2/125000 [00:07<138:43:01, 0.25it/s, v_num=0, global_step=0.000, reduced_train_loss=14.00, tps=1.2e+4, lr=0.0003, train_step_timing in s=5.720]
Epoch 0: 0%| | 2/125000 [00:07<138:43:20, 0.25it/s, v_num=0, global_step=1.000, reduced_train_loss=14.90, tps=34084.0, lr=0.0003, train_step_timing in s=2.260]
Epoch 0: 0%| | 3/125000 [00:10<118:47:45, 0.29it/s, v_num=0, global_step=1.000, reduced_train_loss=14.90, tps=34084.0, lr=0.0003, train_step_timing in s=2.260]
Epoch 0: 0%| | 3/125000 [00:10<118:47:58, 0.29it/s, v_num=0, global_step=2.000, reduced_train_loss=14.90, tps=28905.0, lr=0.0003, train_step_timing in s=2.270]
Epoch 0: 0%| | 4/125000 [00:12<108:50:50, 0.32it/s, v_num=0, global_step=2.000, reduced_train_loss=14.90, tps=28905.0, lr=0.0003, train_step_timing in s=2.270]
Epoch 0: 0%| | 4/125000 [00:12<108:51:00, 0.32it/s, v_num=0, global_step=3.000, reduced_train_loss=14.10, tps=28760.0, lr=0.0003, train_step_timing in s=2.270]
Epoch 0: 0%| | 5/125000 [00:14<102:53:18, 0.34it/s, v_num=0, global_step=3.000, reduced_train_loss=14.10, tps=28760.0, lr=0.0003, train_step_timing in s=2.270]
Epoch 0: 0%| | 5/125000 [00:14<102:53:27, 0.34it/s, v_num=0, global_step=4.000, reduced_train_loss=17.70, tps=28734.0, lr=0.0003, train_step_timing in s=2.270]
Epoch 0: 0%| | 6/125000 [00:17<99:02:51, 0.35it/s, v_num=0, global_step=4.000, reduced_train_loss=17.70, tps=28734.0, lr=0.0003, train_step_timing in s=2.270]
Epoch 0: 0%| | 6/125000 [00:17<99:02:58, 0.35it/s, v_num=0, global_step=5.000, reduced_train_loss=14.40, tps=28522.0, lr=0.0003, train_step_timing in s=2.300]
Epoch 0: 0%| | 7/125000 [00:19<96:17:52, 0.36it/s, v_num=0, global_step=5.000, reduced_train_loss=14.40, tps=28522.0, lr=0.0003, train_step_timing in s=2.300]
Epoch 0: 0%| | 7/125000 [00:19<96:17:58, 0.36it/s, v_num=0, global_step=6.000, reduced_train_loss=13.70, tps=28514.0, lr=0.0003, train_step_timing in s=2.300]
Epoch 0: 0%| | 8/125000 [00:21<94:17:27, 0.37it/s, v_num=0, global_step=6.000, reduced_train_loss=13.70, tps=28514.0, lr=0.0003, train_step_timing in s=2.300]
Epoch 0: 0%| | 8/125000 [00:21<94:17:33, 0.37it/s, v_num=0, global_step=7.000, reduced_train_loss=14.50, tps=28361.0, lr=0.0003, train_step_timing in s=2.310]
Epoch 0: 0%| | 9/125000 [00:24<92:43:17, 0.37it/s, v_num=0, global_step=7.000, reduced_train_loss=14.50, tps=28361.0, lr=0.0003, train_step_timing in s=2.310]
Epoch 0: 0%| | 9/125000 [00:24<92:43:21, 0.37it/s, v_num=0, global_step=8.000, reduced_train_loss=14.50, tps=28349.0, lr=0.0003, train_step_timing in s=2.310]
Epoch 0: 0%| | 10/125000 [00:26<91:25:34, 0.38it/s, v_num=0, global_step=8.000, reduced_train_loss=14.50, tps=28349.0, lr=0.0003, train_step_timing in s=2.310]
Epoch 0: 0%| | 10/125000 [00:26<91:25:38, 0.38it/s, v_num=0, global_step=9.000, reduced_train_loss=15.00, tps=28524.0, lr=0.0003, train_step_timing in s=2.290]
Epoch 0: 0%| | 10/125000 [00:26<91:25:57, 0.38it/s, v_num=0, global_step=9.000, reduced_train_loss=15.00, tps=28524.0, lr=0.0003, train_step_timing in s=2.290]`Trainer.fit` stopped: `max_steps=10` reached.
Epoch 0: 0%| | 10/125000 [00:26<91:26:06, 0.38it/s, v_num=0, global_step=9.000, reduced_train_loss=15.00, tps=28524.0, lr=0.0003, train_step_timing in s=2.290]Current pytorch-triton version: 3.2.0+gitb2684bf3b.nvinternal, Required triton version: 3.2.0
Expected behavior
"global_batch_size" and "micro_batch_size" should affect the training step time. For example, a larger "global_batch_size" means more samples per optimizer step, so we expect a longer training step time.
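A rough sanity check on the log above, assuming the logged tps counts tokens processed across all 8 ranks (an assumption, not confirmed in this thread): with micro_batch_size=1 on 8 data-parallel ranks and 8192-token sequences, each step covers 8 × 8192 = 65,536 tokens, and 65,536 / 28,524 ≈ 2.30 s, close to the logged train_step_timing of about 2.29 s. If global_batch_size=128 were honored at the same throughput, each step would cover 16× as many tokens and take roughly 37 s. The helper names below are illustrative only.

```python
SEQ_LENGTH = 8192  # seq_length used in the reproducer


def tokens_per_step(global_batch_size: int, seq_length: int = SEQ_LENGTH) -> int:
    """Tokens consumed per optimizer step across all ranks."""
    return global_batch_size * seq_length


def implied_step_time(global_batch_size: int, tokens_per_second: float,
                      seq_length: int = SEQ_LENGTH) -> float:
    """Step time implied by a global throughput figure (assumes throughput stays constant)."""
    return tokens_per_step(global_batch_size, seq_length) / tokens_per_second


observed_tps = 28524.0  # tps from the global_step=9 log line
actual_gbs = 8 * 1  # 8 data-parallel ranks x micro_batch_size=1

print(tokens_per_step(actual_gbs))  # 65536
print(round(implied_step_time(actual_gbs, observed_tps), 2))  # 2.3
print(round(implied_step_time(128, observed_tps), 1))  # 36.8
```

The near-match between the implied and logged step times suggests each step processed only 8 sequences, not 128.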
Environment overview
GCP, Docker environment with nvcr.io/nvidia/nemo:25.04
Hi @jiuqiant ,
Please use the accumulate_grad_batches parameter of the Trainer instead.
The global_batch_size parameter in the datamodule has been deprecated; use micro_batch_size to set the batch size on each rank instead.
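A sketch of the arithmetic behind this advice, assuming standard Lightning gradient-accumulation semantics apply under FSDP2Strategy: samples per optimizer step = micro_batch_size × data-parallel size × accumulate_grad_batches. With the reproducer's settings that is 1 × 8 × 1 = 8, so global_batch_size=128 had no effect, which is also consistent with the progress-bar total of 125000 (1,000,000 samples / 8 per step). Reaching 128 would need accumulate_grad_batches=16. The function names are hypothetical helpers, not NeMo APIs.

```python
def effective_global_batch_size(micro_batch_size: int, data_parallel_size: int,
                                accumulate_grad_batches: int = 1) -> int:
    """Samples consumed per optimizer step across all data-parallel ranks."""
    return micro_batch_size * data_parallel_size * accumulate_grad_batches


def batches_per_rank(num_train_samples: int, micro_batch_size: int,
                     data_parallel_size: int) -> int:
    """Micro-batches each rank iterates over in one epoch (remainder dropped)."""
    return num_train_samples // (micro_batch_size * data_parallel_size)


def accumulation_for_target(target_global_batch_size: int, micro_batch_size: int,
                            data_parallel_size: int) -> int:
    """accumulate_grad_batches needed to reach a target global batch size."""
    per_pass = micro_batch_size * data_parallel_size
    if target_global_batch_size % per_pass != 0:
        raise ValueError("target must be a multiple of micro_batch_size * data_parallel_size")
    return target_global_batch_size // per_pass


# Settings from the reproducer: 8 GPUs, micro_batch_size=1, 1,000,000 mock samples.
print(effective_global_batch_size(1, 8))  # 8: global_batch_size=128 is not used
print(batches_per_rank(1_000_000, 1, 8))  # 125000, matching the progress-bar total
print(accumulation_for_target(128, 1, 8))  # 16
```

In the reproducer this would presumably mean passing accumulate_grad_batches=16 to nl.Trainer, which forwards standard Lightning Trainer arguments.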
Please let me know if you still encounter the issue.
Thank you.
@akoumpa thank you for the clarification on accumulate_grad_batches and global_batch_size. Regarding 'train_step_timing in s' in the log, what exactly does it measure? Does it cover only the forward and backward pass? Is there a separate log entry for the optimizer step time?
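The thread does not answer this question. One way to check empirically is to time the phases separately. The sketch below is a framework-agnostic, hypothetical helper (PhaseTimer and training_step are illustrative names, not NeMo or Lightning APIs) that accumulates wall-clock time for forward+backward and for the optimizer update within one step. In Lightning, the same boundaries roughly correspond to Callback hooks such as on_before_backward, on_after_backward and on_before_optimizer_step. For GPU work, call torch.cuda.synchronize() before reading each timer, since CUDA kernels run asynchronously and unsynchronized wall-clock times can be misleading.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class PhaseTimer:
    """Accumulates wall-clock time per named phase (hypothetical helper)."""

    def __init__(self):
        self.totals = defaultdict(float)
        self.counts = defaultdict(int)

    @contextmanager
    def phase(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] += time.perf_counter() - start
            self.counts[name] += 1

    def mean(self, name):
        return self.totals[name] / self.counts[name]


def training_step(timer, forward_backward, optimizer_step):
    """One step, with forward+backward and the optimizer update timed separately."""
    with timer.phase("step"):
        with timer.phase("forward_backward"):
            forward_backward()
        with timer.phase("optimizer_step"):
            optimizer_step()


if __name__ == "__main__":
    timer = PhaseTimer()
    for _ in range(3):
        # Stand-ins for real work: the model's forward/backward and optimizer.step().
        training_step(timer, lambda: time.sleep(0.02), lambda: time.sleep(0.005))
    for name in ("forward_backward", "optimizer_step", "step"):
        print(f"{name}: {timer.mean(name):.3f} s")
```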
This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.
This issue was closed because it has been inactive for 7 days since being marked as stale.