Feature Request: Automatic Mixed Precision
Hi all!
I'd like to ask whether there are any plans to eventually support automatic mixed precision like PyTorch and TensorFlow.
In PyTorch, all you need to do is wrap the forward pass of your training loop with torch.autocast():
```python
# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # Enables autocasting for the forward pass (model + loss)
    with torch.autocast(device_type="cuda"):
        output = model(input)
        loss = loss_fn(output, target)

    # Exits the context manager before backward()
    loss.backward()
    optimizer.step()
```
In TensorFlow, you simply set a global policy:

```python
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')
```
As far as I can tell, Flax is missing a similar mechanism. If this statement is correct, how is one expected to train in mixed precision?
Thanks in advance.
@Artoriuz thanks for asking! There have been previous discussions about AMP in Flax, for example: https://github.com/google/flax/discussions/2027 (which mentions https://github.com/google-deepmind/jmp).
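If jmp fits your setup, its core idea is a policy object whose cast helpers you apply explicitly inside your own code; here is a rough sketch (the apply_dense function and the params layout are made up for illustration, only the jmp calls are the library's API):

```python
# A rough sketch of how jmp works: you build a policy once and apply its
# cast_* helpers explicitly around your computation (nothing is automatic).
import jax.numpy as jnp
import jmp

policy = jmp.Policy(
    param_dtype=jnp.float32,     # keep parameters in full precision
    compute_dtype=jnp.bfloat16,  # run the matmuls in low precision
    output_dtype=jnp.float32,    # return full-precision outputs
)

def apply_dense(params, x):
    # The layer itself is a placeholder; the point is the explicit casting.
    params = policy.cast_to_compute(params)
    x = policy.cast_to_compute(x)
    y = x @ params["kernel"] + params["bias"]
    return policy.cast_to_output(y)
```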
There is no amp.autocast context manager in Flax similar to the PyTorch one. What I have seen instead is explicit conversion of dtypes, for example in the MaxText RMSNorm layer:
https://github.com/AI-Hypercomputer/maxtext/blob/7070e8eecbea8951c8e5281219ce797c8df1441f/MaxText/layers/normalizations.py#L30-L70
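The pattern roughly looks like the following minimal sketch (the class name, defaults, and shapes are my own for illustration, not the MaxText code): parameters stay in float32, the reduction runs in float32, and the result is cast back down to the compute dtype.

```python
import jax
import jax.numpy as jnp
from flax import nnx


class RMSNormSketch(nnx.Module):
    def __init__(self, features: int, *, epsilon: float = 1e-6,
                 dtype=jnp.bfloat16, param_dtype=jnp.float32):
        self.scale = nnx.Param(jnp.ones((features,), dtype=param_dtype))
        self.epsilon = epsilon
        self.dtype = dtype

    def __call__(self, x):
        # Do the variance reduction in float32 for numerical stability ...
        x32 = x.astype(jnp.float32)
        var = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
        normed = x32 * jax.lax.rsqrt(var + self.epsilon)
        # ... then cast the result back to the low-precision compute dtype.
        return (normed * self.scale.value.astype(jnp.float32)).astype(self.dtype)


layer = RMSNormSketch(512)
y = layer(jnp.ones((4, 512), dtype=jnp.bfloat16))  # y.dtype == bfloat16
```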
Linear, Convolution, etc. layers can similarly specify the precision and the various dtypes (see the sketch below):
- https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear

(there is, however, a feature request to expose the preferred_element_type arg in https://github.com/google/flax/issues/4890).
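For instance, with flax.nnx you can keep parameters in float32 while computing in bfloat16 on a per-layer basis; a small sketch (the shapes are arbitrary):

```python
# Parameters stored in float32, computation and output in bfloat16.
import jax.numpy as jnp
from flax import nnx

layer = nnx.Linear(
    in_features=256,
    out_features=512,
    dtype=jnp.bfloat16,       # compute/output dtype
    param_dtype=jnp.float32,  # storage dtype of kernel and bias
    rngs=nnx.Rngs(0),
)

x = jnp.ones((8, 256), dtype=jnp.bfloat16)
y = layer(x)  # y.dtype == bfloat16, while the kernel stays float32
```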
There are also examples of how to perform low-precision training (a rough sketch of the overall recipe follows the links):
- https://github.com/google/flax/blob/f73aea5cc605d9e4530132ca3569b04721942f36/examples/gemma/train.py#L124
- https://github.com/google/flax/blob/main/examples/lm1b_nnx/train.py#L433
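The common recipe in those examples is to build the model with a low-precision compute dtype while keeping parameters (and optimizer state) in float32. A rough sketch along those lines with flax.nnx and optax (the MLP, shapes, and loss below are placeholders, not code taken from the linked examples):

```python
import jax.numpy as jnp
import optax
from flax import nnx


class MLP(nnx.Module):
    def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
        kwargs = dict(dtype=jnp.bfloat16, param_dtype=jnp.float32, rngs=rngs)
        self.fc1 = nnx.Linear(din, dhidden, **kwargs)
        self.fc2 = nnx.Linear(dhidden, dout, **kwargs)

    def __call__(self, x):
        return self.fc2(nnx.relu(self.fc1(x)))


model = MLP(128, 256, 10, rngs=nnx.Rngs(0))


@nnx.jit
def loss_and_grads(model, batch, labels):
    def loss_fn(model):
        logits = model(batch)  # forward pass runs in bfloat16
        # compute the loss in float32 for numerical stability
        return optax.softmax_cross_entropy_with_integer_labels(
            logits.astype(jnp.float32), labels
        ).mean()

    # gradients mirror the float32 parameters, so a standard optax
    # optimizer can be applied without any extra casting
    return nnx.value_and_grad(loss_fn)(model)
```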
Hope this helps