ModernBERT

Bugfix: batch_size_warmup_scheduler was taking too long

Open · onurgu opened this issue 11 months ago · 1 comment

BatchSizeWarmupScheduler was taking too long, or was effectively unusable, for real-world max_batch_size values

When running the training script as follows:

conda run -n bert24 composer main.py yamls/modernbert/modernbert-base-pretrain.yaml

the script produced no output for a long while, so I started reading the code. I found that it was using the sum(range(x, y)) idiom to sum the values over a range. This is inefficient for large y, and effectively impossible when y is on the order of 50B.
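The difference is easy to demonstrate in isolation: sum(range(x, y)) iterates over every integer in the range, while the arithmetic series has a closed form that is constant time. The function names below are illustrative, not from the repository.

```python
def range_sum_slow(x: int, y: int) -> int:
    # Iterates over every integer in [x, y): O(y - x) time.
    # For y around 50B this never finishes in practice.
    return sum(range(x, y))

def range_sum_fast(x: int, y: int) -> int:
    # Closed-form arithmetic series sum: n * (first + last) // 2, O(1) time.
    n = y - x
    return n * (x + (y - 1)) // 2
```

With the closed form, even y = 50_000_000_000 is answered instantly, whereas the iterative version is the source of the long stall described above.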

Changes

Simplify BatchSizeWarmupScheduler Implementation

Summary

This PR simplifies the batch size warmup scheduling logic by replacing the step-based threshold calculation with a more straightforward token-based approach. The new implementation provides a more intuitive and mathematically precise way to handle batch size warmup during training.

Changes

  • Replaced _calculate_step_thresholds() with _calculate_tokens_per_batch_size()
  • Changed the scheduler to use token counts instead of steps for determining batch sizes
  • Simplified the batch size calculation using a direct mathematical formula
  • Updated method signatures to better reflect the token-based approach (current_step → current_token_count)

Technical Details

The new implementation:

  1. Calculates total batch sizes using the arithmetic sequence sum formula: (n(a₁ + aₙ))/2
  2. Determines tokens per batch size unit by dividing warmup tokens by total batch sizes
  3. Uses integer division to determine how many batch size increments to apply
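The three steps above can be sketched as follows. This is a hypothetical reconstruction of the scheduler logic from the PR description, not the actual code: the function name, parameter names, and the exact hold-time semantics (each batch size b held for b units' worth of tokens) are assumptions.

```python
import math

def current_batch_size(token_count: int, min_bs: int, max_bs: int,
                       warmup_tokens: int) -> int:
    """Hypothetical sketch: map a token count to a warmup batch size.

    Assumes each batch size b in [min_bs, max_bs] is held for
    b * tokens_per_unit tokens, so larger batch sizes occupy
    proportionally more of the warmup.
    """
    if token_count >= warmup_tokens:
        return max_bs
    # Step 1: total batch size units via the arithmetic series sum (n(a1 + an))/2.
    n = max_bs - min_bs + 1
    total_units = n * (min_bs + max_bs) // 2
    # Step 2: tokens allotted to each batch size unit.
    tokens_per_unit = warmup_tokens / total_units
    units_consumed = token_count / tokens_per_unit
    # Step 3: largest b with (b + min_bs)(b - min_bs + 1)/2 <= units_consumed,
    # solved in closed form via the quadratic formula, floored by integer division.
    b = int((-1 + math.sqrt((2 * min_bs - 1) ** 2 + 8 * units_consumed)) // 2)
    return max(min_bs, min(b + 1, max_bs))
```

For example, with min_bs=1, max_bs=4, and warmup_tokens=100, each unit is worth 10 tokens, so the schedule holds batch size 1 for 10 tokens, 2 for 20, 3 for 30, and reaches 4 at token 60.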

Benefits

  • More precise control over batch size progression
  • Simpler, more maintainable code with fewer loops and conditionals
  • Direct relationship between token count and batch size
  • Reduced memory footprint by eliminating the need to store threshold arrays

Discussions

If any, please include references to the relevant issues/previous PR/discord discussions around these changes.

Tests

  • [ ] Is the new feature tested? (Not always necessary for all changes -- just adding to the checklist to keep track)
  • [ ] Have you run all the tests?
  • [ ] Do the tests all pass?
  • [ ] If not, have you included an explanation of which tests this PR breaks and/or why (below this checklist)

onurgu avatar Feb 24 '25 21:02 onurgu

I have a question regarding your statement that using the 'sum(range(x, y))' idiom to sum values in a range is inefficient for large y – to the point of being impractical when y is around 50B, for example.

My understanding is that x and y are derived from batch size variables and are not related to the number of tokens. Could you clarify why you consider a scenario where y equals 50B?

jihobak avatar Mar 19 '25 02:03 jihobak