`jnp.frexp` not matching NumPy on subnormals
Description
The function jnp.frexp does not match the NumPy result on subnormals (I would personally consider the NumPy result to be the more accurate one).
import numpy as np
import jax.numpy as jnp
v = np.finfo(np.float16).smallest_subnormal
np.frexp(v) # returns (0.5, -23)
jnp.frexp(v) # returns (Array(0.5005, dtype=float16), Array(-14, dtype=int32))
A similar issue exists with FP32 too.
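As an independent cross-check of the NumPy values, the same decomposition can be computed with Python's standard `math.frexp` after converting to a Python float (float64). Both smallest subnormals are exact in float64, where they are normal numbers, so no flushing question arises. A minimal sketch; `reference_frexp` is a hypothetical helper name:

```python
import math

import numpy as np


def reference_frexp(dtype):
    """frexp of the smallest positive subnormal of `dtype`, computed in float64."""
    v = np.finfo(dtype).smallest_subnormal  # 2**-24 for float16, 2**-149 for float32
    # float(v) is exact: every float16/float32 value is representable in float64,
    # where these particular values are normal numbers.
    return math.frexp(float(v))


for dt in (np.float16, np.float32):
    print(np.dtype(dt).name, reference_frexp(dt))
# float16 (0.5, -23)
# float32 (0.5, -148)
```

The float32 line gives the reference value for the FP32 case mentioned above.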
What jax/jaxlib version are you using?
jax 0.4.13, jaxlib 0.4.13
Which accelerator(s) are you using?
CPU
Additional system info?
Python 3.8
NVIDIA GPU info
No response
Thanks for the report. I think this is expected: XLA flushes subnormal values to zero during operations, because hardware such as TPUs and some GPUs does not support them.
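To make "flushing subnormals" concrete, flush-to-zero can be emulated in NumPy: any value whose magnitude is below `finfo(dtype).tiny` (the smallest normal number) is treated as a zero of the same sign. This is only a sketch of the semantics, not XLA's implementation; `flush_subnormals` and `ftz_multiply` are hypothetical helpers, and flushing both inputs and the result is an assumption about how such hardware behaves.

```python
import numpy as np


def flush_subnormals(x):
    """Emulate flush-to-zero: replace subnormal values with a zero of the same sign."""
    x = np.asarray(x)
    tiny = np.finfo(x.dtype).tiny  # smallest positive *normal* value
    signed_zero = np.copysign(np.zeros_like(x), x)
    return np.where(np.abs(x) < tiny, signed_zero, x)


def ftz_multiply(a, b):
    """Multiply as flush-to-zero hardware might: flush the inputs and the result."""
    return flush_subnormals(flush_subnormals(a) * flush_subnormals(b))


v = np.finfo(np.float16).smallest_subnormal
print(v * np.float16(2))               # IEEE result: 2**-23, still subnormal
print(ftz_multiply(v, np.float16(2)))  # emulated flush-to-zero result: 0.0
```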
I did a quick check yesterday: the following code runs fine and returns the correct result on recent NVIDIA GPUs (at least the ML-oriented ones):
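import jax
import numpy as np
val_f16 = np.finfo(np.float16).smallest_subnormal
val_f32 = np.finfo(np.float32).smallest_subnormal
@jax.jit
def fn(v):
    # Doubling the smallest subnormal gives a value that is still subnormal.
    return v * 2
out_f16 = fn(val_f16)
out_f32 = fn(val_f32)
# Each line prints True if the jitted result matches NumPy's (non-flushed) result.
print(out_f16 == (val_f16 * 2), out_f16, out_f16.dtype)
print(out_f32 == (val_f32 * 2), out_f32, out_f32.dtype)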
On TPU, it indeed flushes to zero.
It's not necessarily an easy decision, but my take is that frexp is just a mantissa + exponent split, i.e. closer to a bitwise masking operation than to an arithmetic op (add, mul, ...). Hence I would expect the result to be the same across platforms.
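The "bitwise split" view can be sketched for float16 using only integer operations on the bit pattern, so flush-to-zero arithmetic cannot affect it. The helper names below are hypothetical and this is not JAX's implementation; inf/NaN handling is omitted. The second function deliberately treats a subnormal as if it were normal; it happens to reproduce the (0.5005, -14) reported for jnp.frexp, which is consistent with (but does not prove) the subnormal case being lost somewhere along the way.

```python
import numpy as np

# float16 layout: 1 sign bit, 5 exponent bits (bias 15), 10 stored mantissa bits.
NMANT, BIAS = 10, 15
EXP_MASK, MANT_MASK = 0x1F, 0x3FF


def _bits(x):
    return int(np.asarray(x, dtype=np.float16).view(np.uint16))


def _from_bits(b):
    return np.asarray(b, dtype=np.uint16).view(np.float16)[()]


def frexp_f16_bits(x):
    """Hypothetical frexp for one finite float16 value, integer ops only."""
    bits = _bits(x)
    sign = bits & 0x8000
    exp_field = (bits >> NMANT) & EXP_MASK
    mant = bits & MANT_MASK
    if exp_field == 0 and mant == 0:
        return _from_bits(bits), 0  # frexp(+/-0) == (+/-0, 0)
    if exp_field == 0:
        # Subnormal: value = mant * 2**-24. Shift the leading 1 up to the
        # implicit-bit position, then drop it from the stored mantissa.
        shift = NMANT + 1 - mant.bit_length()
        mant = (mant << shift) & MANT_MASK
        exponent = 2 - BIAS - shift
    else:
        exponent = exp_field - BIAS + 1
    # A stored exponent of BIAS - 1 puts the mantissa in [0.5, 1).
    return _from_bits(sign | ((BIAS - 1) << NMANT) | mant), exponent


def frexp_f16_no_subnormal_path(x):
    """Hypothetical variant that handles subnormals as if they were normal."""
    bits = _bits(x)
    exp_field = (bits >> NMANT) & EXP_MASK
    mant_bits = (bits & ~(EXP_MASK << NMANT)) | ((BIAS - 1) << NMANT)
    return _from_bits(mant_bits), exp_field - BIAS + 1


v = np.finfo(np.float16).smallest_subnormal
print(frexp_f16_bits(v))               # mantissa 0.5, exponent -23 (as np.frexp)
print(frexp_f16_no_subnormal_path(v))  # mantissa ~0.5005, exponent -14
```

Because the first function only masks and shifts bits, it gives the same answer on any backend, which is the property argued for above.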