ImportError: cannot import name 'Array' from 'jaxtyping'
```python
import jaxtyping
print(jaxtyping.__version__)  # prints 0.2.14
# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float32, jaxtyped
```
raises:
```
ImportError: cannot import name 'Array' from 'jaxtyping' (/home/jovyan/conda/lib/python3.8/site-packages/jaxtyping/__init__.py)
```
You need to have JAX installed as well.
jaxtyping has JAX only as an optional dependency, so that it can also be used with PyTorch etc.
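A minimal sketch of how downstream code might check for the optional JAX dependency before relying on `jaxtyping.Array`, using only the standard library. The helper name `module_available` and the `HAS_JAX` flag are hypothetical:

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


# Hypothetical flag: True only when JAX is installed in this environment.
HAS_JAX = module_available("jax")

if HAS_JAX:
    print("JAX found: jaxtyping.Array should be importable")
else:
    print("JAX not found: annotate with your backend's array type instead")
```

`find_spec` only reports that the module can be located; it does not guarantee the module will import cleanly.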
Ran into this as well with PyTorch. For me the solution, as described at https://docs.kidger.site/jaxtyping/api/array/#array, was to use torch.Tensor in place of jaxtyping.Array, like so:
```python
from torch import Tensor
from jaxtyping import Float32

def f(x: Float32[Tensor, "dim1 dim2"]) -> Float32[Tensor, "dim1 dim2"]:
    return x
```
@MilesCranmer thanks. @patrick-kidger, has the JAX requirement been relaxed? I don't see it in pyproject.toml anymore.
If so, I'll modify my code to use the syntax Miles suggested.
Sorry, missed this question. Yes, jaxtyping no longer depends on JAX. The name is now for historical reasons only! The syntax Miles is using is correct.
When authoring ML-runtime-agnostic tooling, such as a dataset, what is the correct array type to use? I cannot assume that any of torch, jax or tensorflow is installed. I currently assume at least numpy and do the following, but it might not work for other use cases:
```python
from typing import Union, TYPE_CHECKING
from jaxtyping import Float, Bool

if TYPE_CHECKING:
    from torch import Tensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    # TODO: tensorflow
    Array = Union[Tensor, ndarray, JaxArray]
else:
    from numpy import ndarray as Array
```
Probably something like this:
```python
from typing import Union, TYPE_CHECKING

if TYPE_CHECKING:
    from torch import Tensor as TorchTensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    from tensorflow import Tensor as TfTensor

    Array = Union[TorchTensor, ndarray, JaxArray, TfTensor]
else:
    arrays = []
    try:
        from torch import Tensor as TorchTensor
    except Exception:
        pass
    else:
        arrays.append(TorchTensor)
    try:
        from numpy import ndarray
    except Exception:
        pass
    else:
        arrays.append(ndarray)
    try:
        from jaxtyping import Array as JaxArray
    except Exception:
        pass
    else:
        arrays.append(JaxArray)
    try:
        from tensorflow import Tensor as TfTensor
    except Exception:
        pass
    else:
        arrays.append(TfTensor)
    # Raises TypeError if none of the backends could be imported.
    Array = Union[tuple(arrays)]
```
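The four near-identical try/except blocks can also be driven by a table of `(module, attribute)` pairs. A minimal sketch; the helper name `union_of_available` is hypothetical, and falling back to `typing.Any` when nothing imports (instead of the `TypeError` that `Union[()]` raises) is a design choice of this sketch:

```python
import importlib
from typing import Any, Union


def union_of_available(candidates):
    """Build a Union of the types whose modules import successfully.

    `candidates` is a list of (module_name, attribute_name) pairs.
    """
    found = []
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
            found.append(getattr(module, attr))
        except Exception:  # deliberately broad: a broken install may raise more than ImportError
            continue
    if not found:
        return Any  # nothing available: fall back rather than fail
    return Union[tuple(found)]


Array = union_of_available([
    ("torch", "Tensor"),
    ("numpy", "ndarray"),
    ("jaxtyping", "Array"),
    ("tensorflow", "Tensor"),
])
```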
Neat! I'd go for except (ModuleNotFoundError, ImportError): :wink:
And it doesn't exactly roll off the tongue. Any chance this could be added to jaxtyping?
Actually, the more general Exception is deliberate. There are cases where try-importing a module can raise other kinds of errors too, cf. https://github.com/google/jaxtyping/blob/7a84b27da9e57c425ce4e6333121c3cdf2e07302/jaxtyping/_array_types.py#L33-L39
As for adding the above to jaxtyping: jaxtyping tries to be essentially backend-agnostic. In particular, I don't think I'd want to hardcode that it looks for specifically torch+numpy+tensorflow+jax and nothing else. As such, I think something like this is out of scope for jaxtyping, I'm afraid.
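To make the point about catching the broad `Exception` concrete, here is a self-contained sketch using only the standard library. The module `broken_backend` is hypothetical, written to a temporary directory just for this demo: it can be found on the path but raises `RuntimeError` when executed, roughly like an installed package whose import fails for reasons other than being missing. `try_import` is likewise an illustrative helper, not part of jaxtyping:

```python
import importlib
import sys
import tempfile
from pathlib import Path


def try_import(name, catch):
    """Import `name`; return None if the import raised an exception matching `catch`."""
    try:
        return importlib.import_module(name)
    except catch:
        return None


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical module that is present but fails at import time with a
    # non-ImportError exception, standing in for a broken or mismatched install.
    Path(tmp, "broken_backend.py").write_text("raise RuntimeError('incompatible build')\n")
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        # Narrow handler: RuntimeError is not an ImportError, so it escapes.
        try:
            try_import("broken_backend", (ModuleNotFoundError, ImportError))
            narrow_escaped = False
        except RuntimeError:
            narrow_escaped = True

        # Broad handler: the failure is swallowed and the backend is skipped.
        broad_result = try_import("broken_backend", Exception)
    finally:
        sys.path.remove(tmp)
        sys.modules.pop("broken_backend", None)

print(narrow_escaped, broad_result)  # True None
```

With the narrow handler the `RuntimeError` propagates to the caller; with `except Exception` the broken backend is simply left out of the union. Note also that `ModuleNotFoundError` is a subclass of `ImportError`, so listing both in one `except` clause is redundant.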