Bugfix: batch_size_warmup_scheduler was taking too long
BatchSizeWarmupScheduler was taking far too long, or effectively never finished, for real-world max_batch_size values.
When running the training script as follows:
conda run -n bert24 composer main.py yamls/modernbert/modernbert-base-pretrain.yaml
the script produced no output for a very long time. Reading the code, I found that it used the `sum(range(x, y))` idiom to sum the values in a range. This iterates over every value, so it is O(y - x): inefficient for large y, and effectively impossible to finish when y is on the order of 50B.
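As a minimal illustration (not the actual PR diff), the fix replaces the linear-time idiom with the closed-form arithmetic-series sum, which is O(1) regardless of the range size; the function names here are made up for the example:

```python
def slow_range_sum(x: int, y: int) -> int:
    # sum(range(x, y)) allocates no list, but still iterates y - x times.
    return sum(range(x, y))

def fast_range_sum(x: int, y: int) -> int:
    # Closed-form sum of the integers in [x, y): n terms, first term x,
    # last term y - 1, so the total is n * (a1 + an) / 2.
    n = y - x
    return n * (x + y - 1) // 2

# Both agree on small ranges, but only the closed form is usable
# when y is in the billions.
assert slow_range_sum(3, 10) == fast_range_sum(3, 10)
```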
Changes
Simplify BatchSizeWarmupScheduler Implementation
Summary
This PR simplifies the batch size warmup scheduling logic by replacing the step-based threshold calculation with a more straightforward token-based approach. The new implementation provides a more intuitive and mathematically precise way to handle batch size warmup during training.
Changes
- Replaced `_calculate_step_thresholds()` with `_calculate_tokens_per_batch_size()`
- Changed the scheduler to use token counts instead of steps for determining batch sizes
- Simplified the batch size calculation using a direct mathematical formula
- Updated method signatures to better reflect the token-based approach (`current_step` → `current_token_count`)
Technical Details
The new implementation:
- Calculates the total of all batch sizes using the arithmetic sequence sum formula: (n(a₁ + aₙ))/2
- Determines tokens per batch size unit by dividing the warmup tokens by that total
- Uses integer division to determine how many batch size increments to apply
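The three steps above can be sketched roughly as follows. This is a hedged illustration, not the PR's actual implementation: the class and attribute names (`min_batch_size`, `max_batch_size`, `warmup_tokens`) are assumptions, and the exact increment rule in `BatchSizeWarmupScheduler` may differ.

```python
class TokenBasedWarmupSketch:
    """Illustrative token-based batch size warmup (names are hypothetical)."""

    def __init__(self, min_batch_size: int, max_batch_size: int, warmup_tokens: int):
        self.min_batch_size = min_batch_size
        self.max_batch_size = max_batch_size
        # Step 1: arithmetic-series sum of all batch sizes in
        # [min_batch_size, max_batch_size]: n * (a1 + an) / 2.
        n = max_batch_size - min_batch_size + 1
        total_batch_sizes = n * (min_batch_size + max_batch_size) // 2
        # Step 2: tokens allotted to each unit of batch size.
        self.tokens_per_batch_size = warmup_tokens / total_batch_sizes

    def batch_size(self, current_token_count: int) -> int:
        # Step 3: integer division gives the number of increments earned,
        # capped at max_batch_size once warmup is complete.
        increments = int(current_token_count // self.tokens_per_batch_size)
        return min(self.min_batch_size + increments, self.max_batch_size)
```

Note that no threshold array is stored: the batch size is recomputed directly from `current_token_count`, which is where the reduced memory footprint mentioned below comes from.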
Benefits
- More precise control over batch size progression
- Simpler, more maintainable code with fewer loops and conditionals
- Direct relationship between token count and batch size
- Reduced memory footprint by eliminating the need to store threshold arrays
Discussions
If any, please include references to the relevant issues/previous PR/discord discussions around these changes.
Tests
- [ ] Is the new feature tested? (Not always necessary for all changes -- just adding to the checklist to keep track)
- [ ] Have you run all the tests?
- [ ] Do the tests all pass?
- [ ] If not, have you included an explanation of which tests this PR breaks and/or why (below this checklist)
I have a question about your statement that the `sum(range(x, y))` idiom is inefficient for large y, to the point of being impractical when y is around 50B.
My understanding is that x and y are derived from batch size variables and are unrelated to the number of tokens. Could you clarify the scenario in which y would reach 50B?