Add fp8 types exposed in jax.numpy.
Thank you for the contribution! These all look good to me. Do you have a link to any documentation on these?
Yes. They are defined in https://github.com/jax-ml/jax/blob/main/jax/_src/dtypes.py#L92-L97. I have updated the patch to include a link.
Hi @patrick-kidger, can we merge this change into main?
Is there any proper documentation, though? Not just their existence in the source code. I'm a little hesitant to expand our own public API to include undocumented features.
(Advanced users such as yourself can continue to subclass AbstractDtype to obtain jaxtyping annotations for these dtypes anyway, after all.)
As with many open-source projects, this is a volunteer effort that happens primarily in my evenings and weekends. Please have a little patience.
Unfortunately, I cannot find any public documentation about the fp8 dtypes. Commit https://github.com/jax-ml/jax/commit/d203926c16a98cf87e88bf090098e6372267b8b2 was the first to add fp8 support to JAX, and it has no additional information attached.
A kind reminder about this PR.
Since these dtypes are not yet public in JAX, I'm afraid I don't think we should make them public either.
That might change in the future, though :)
Hi Patrick – these are public symbols in JAX, so it should be safe to add them here. Thanks!
What greater confirmation could we ask for :) Thanks Jake -- in that case, merged!