
pure fp16 ?

Open vince62s opened this issue 11 months ago • 3 comments

Hi, first, many thanks for the package, it's very useful. It works great with pure bf16 + Kahan. With fp16 I can't find a set of settings that gives stability; I tried eps=1e-6 / weight_decay=1e-5 but it's unstable.
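For reference, a minimal sketch of the setup being described, assuming optimi's AdamW and its kahan_sum flag (the stand-in model and hyperparameters are illustrative, not from this thread):

```python
import torch
from torch import nn
from optimi import AdamW  # optimi's AdamW with optional Kahan summation

# A stand-in model; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))
model = model.cuda().to(torch.bfloat16)

# Pure bf16 + Kahan summation: the combination that trains stably.
optimizer = AdamW(model.parameters(), lr=1e-4, kahan_sum=True)

# The fp16 attempt (model cast to torch.float16, eps=1e-6, weight_decay=1e-5)
# diverges in the same loop without some form of gradient scaling.
```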

EDIT: I think it can't work without a dynamic grad scaler

Cheers, V.

vince62s · Mar 08 '25 13:03

Pure-FP16 training being less stable than pure-BF16 matches what I would expect. That's why optimi's documentation only mentions training in pure-BF16.

I didn't restrict anyone from using it, since it's trivial to support (check param.dtype in [torch.float16, torch.bfloat16] instead of param.dtype == torch.bfloat16) and maybe someone would find it useful.
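A sketch of what that dtype check might look like (paraphrased for illustration, not optimi's actual source):

```python
import torch

def is_low_precision(param: torch.Tensor) -> bool:
    # Widening this check from bf16-only to bf16-or-fp16 is all it takes
    # to let fp16 parameters use the Kahan-compensated update path.
    return param.dtype in (torch.float16, torch.bfloat16)
```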

warner-benjamin · Mar 09 '25 19:03

I am also trying to use optimi in conjunction with GradScaler. However, when calling step() I get this error:

self.optim.step()
  File "/mnt/InternalCrucial4/nlp/eole/eole/utils/optimizers.py", line 414, in step
    self.scaler.unscale_(self.optimizer)
  File "/home/vincent/miniconda3/envs/pt2.5/lib/python3.11/site-packages/torch/amp/grad_scaler.py", line 342, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
                                              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/vincent/miniconda3/envs/pt2.5/lib/python3.11/site-packages/torch/amp/grad_scaler.py", line 283, in _unscale_grads_
    torch._amp_foreach_non_finite_check_and_unscale_(
RuntimeError: "_amp_foreach_non_finite_check_and_unscale_cuda" not implemented for 'BFloat16'

I tried to instantiate AdamW with foreach=False, but I get the same issue. Using the same path with AMP/bf16 (and torch.optim) works fine.

EDIT: Never mind, PyTorch won't do the grad scaling/unscaling with lower-precision grads; this needs to be done separately.
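A sketch of what doing it separately could look like with a fixed scale factor, entirely outside torch.amp.GradScaler (the scale value, model, dataloader, and optimizer are placeholders):

```python
import torch

scale = 2.0 ** 14  # illustrative static scale; in practice it should adapt

for batch in dataloader:
    optimizer.zero_grad(set_to_none=True)
    loss = model(batch)        # placeholder forward returning a scalar loss
    (loss * scale).backward()  # scale the loss so small fp16 grads don't underflow

    # Unscale in place and check for overflow before stepping.
    found_inf = False
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(scale)
            if not torch.isfinite(p.grad).all():
                found_inf = True

    if not found_inf:
        optimizer.step()       # skip the step entirely on inf/nan grads
```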

vince62s · Mar 10 '25 09:03

Okay, see my comment in PyTorch upstream: https://github.com/pytorch/pytorch/issues/127176#issuecomment-2710153787

Then I dug a bit and found this: https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/optim/_dynamic_loss_scaler.py#L251-L263

I think this is exactly what we need in order to add a scaling/unscaling function for pure fp16/bf16. For pure bf16 it would prevent gradient underflow on top of the Kahan compensation.
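A minimal dynamic-scale loop in that spirit (class name and constants are illustrative, not fairseq2's actual implementation):

```python
import torch

class SimpleDynamicScaler:
    """Grow the scale after a run of clean steps, back off on inf/nan grads."""

    def __init__(self, init_scale=2.0 ** 15, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def step(self, optimizer, parameters):
        parameters = list(parameters)
        found_inf = False
        for p in parameters:
            if p.grad is not None:
                p.grad.div_(self.scale)          # unscale in place
                if not torch.isfinite(p.grad).all():
                    found_inf = True

        if found_inf:
            self.scale *= self.backoff_factor    # overflow: shrink and skip the step
            self._good_steps = 0
        else:
            optimizer.step()
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                self.scale *= self.growth_factor  # stable: grow the scale back
```

The training loop would then call (loss * scaler.scale).backward() followed by scaler.step(optimizer, model.parameters()) instead of a plain optimizer.step().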

I don't know if it's the same for you, but pure bf16 remains worse than bf16 AMP even when keeping all LayerNorm layers in FP32.
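For concreteness, keeping LayerNorm in FP32 while the rest of the model is bf16 can look roughly like this (a sketch; depending on the PyTorch version, the activations around the norm may also need casting):

```python
import torch
from torch import nn

def to_bf16_except_layernorm(model: nn.Module) -> nn.Module:
    model.to(torch.bfloat16)
    # Cast normalization layers (and their parameters) back to fp32,
    # since their statistics are the most precision-sensitive part.
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            module.float()
    return model
```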

vince62s · Mar 10 '25 17:03