[BUG] Wrong loss scaling when context parallelism is enabled
Describe the bug
Hi, I think there is a bug when context parallelism is enabled, and I would like to discuss it: https://github.com/NVIDIA/Megatron-LM/blob/0bc3547702464501feefeb5523b7a17e591b21fa/pretrain_gpt.py#L148
From this issue, I understand that the results are the same for dp2cp4 and dp8 with the same global batch size.
However, the logic described in that issue differs from the current code:
- In the issue's logic, the loss is scaled by cp_size and grad_data is scaled by the world size of get_data_parallel_group(with_context_parallel=True).
- In the current code, the loss is scaled by cp_size, but grad_data is scaled by the world size of get_data_parallel_group().
The two logics produce different grad_data (verified by printing grad_data after the all-reduce).
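To make the mismatch concrete, here is a toy sketch of the effective scale each logic applies to the gradient (dp_size, cp_size, and grad are illustrative names, not the actual Megatron-LM variables):

```python
# Toy arithmetic only; dp_size, cp_size, and grad are illustrative stand-ins.
# With dp2cp4 the combined DP+CP group has 8 ranks, the same as plain dp8.
dp_size, cp_size = 2, 4
grad = 1.0  # stand-in for a weight gradient accumulated within one CP group

# Logic in the linked issue: loss scaled by cp_size, grad_data averaged over
# get_data_parallel_group(with_context_parallel=True), i.e. 8 ranks.
grad_issue = grad * cp_size / (dp_size * cp_size)

# Logic in the current code: loss still scaled by cp_size, but grad_data
# averaged over get_data_parallel_group(), i.e. only 2 ranks.
grad_current = grad * cp_size / dp_size

print(grad_issue)                 # 0.5 -> the cp_size factor cancels
print(grad_current / grad_issue)  # 4.0 -> cp_size times larger
```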
To Reproduce
Running dp2cp4 and dp8 with the same parameters reproduces the discrepancy.
Proposed fix
Remove the loss scaling by cp_size in loss_func.
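For illustration, a hypothetical, stripped-down sketch of where that scaling sits (the real loss_func in pretrain_gpt.py handles more than this; the cp_size argument here is just a stand-in for the context-parallel world size):

```python
import torch

def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor, cp_size: int):
    # Simplified token-averaged LM loss; `cp_size` stands in for the
    # context-parallel world size.
    losses = output_tensor.float().view(-1)
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses * loss_mask) / loss_mask.sum()
    # Proposed fix: return `loss` directly instead of `loss * cp_size`.
    return loss * cp_size
```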
@xrennvidia Can you please help to answer this question? Thank you!
Hi @zhaoyinglia
Thanks for reaching out. Could you please point me to the code for the following?
In the current code, the loss is scaled by cp_size, but grad_data is scaled by the world size of get_data_parallel_group().
Weight parameters are replicated across the combined CP+DP group, so we are supposed to all-reduce or reduce-scatter the weight gradients across the combined CP+DP group. If the grad_data scaling change described above is indeed present, we need to fix it.
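For concreteness, the intended reduction looks roughly like the sketch below (a simplified illustration using the group getter mentioned above, not the actual distributed-optimizer code; reduce_weight_grads is a hypothetical helper):

```python
import torch
import torch.distributed as dist
from megatron.core import parallel_state as mpu

def reduce_weight_grads(grad_data: torch.Tensor) -> None:
    # Weights are replicated over the combined DP+CP ranks, so gradients are
    # averaged over that combined group rather than the DP-only group.
    group = mpu.get_data_parallel_group(with_context_parallel=True)
    grad_data /= dist.get_world_size(group=group)
    dist.all_reduce(grad_data, group=group)  # sum of pre-divided shards == mean
```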
The issue has been resolved by https://github.com/NVIDIA/Megatron-LM/commit/3bdcbbbe5d2a455a75e28969be7250cd4bd27bae. Thank you!