
ImportError: cannot import name 'Array' from 'jaxtyping'

Open · danbider opened this issue 2 years ago · 8 comments

```python
import jaxtyping
print(jaxtyping.__version__)  # prints 0.2.14
# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float32, jaxtyped
```

This raises:

```
ImportError: cannot import name 'Array' from 'jaxtyping' (/home/jovyan/conda/lib/python3.8/site-packages/jaxtyping/__init__.py)
```

danbider · Mar 20 '23

You need to have JAX installed as well.

jaxtyping has JAX only as an optional dependency, so that it can also be used with PyTorch etc.
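A minimal sketch for checking up front whether JAX is present and whether `jaxtyping.Array` is therefore importable, using only the standard library. `importlib.util.find_spec` reports whether a top-level package is installed without importing it. The names `backend_available` and `HAS_JAXTYPING_ARRAY` are illustrative only, not part of jaxtyping.

```python
import importlib.util


def backend_available(module_name: str) -> bool:
    """Return True if the top-level module `module_name` is installed.

    For top-level names, find_spec locates the module without executing it,
    so this is cheap even for heavy libraries such as jax or torch.
    """
    return importlib.util.find_spec(module_name) is not None


try:
    # `Array` is only exported by jaxtyping when JAX is installed; this also
    # lands in the except branch if jaxtyping itself is missing.
    from jaxtyping import Array  # noqa: F401

    HAS_JAXTYPING_ARRAY = True
except ImportError:
    HAS_JAXTYPING_ARRAY = False

print("jax installed:", backend_available("jax"))
print("jaxtyping.Array importable:", HAS_JAXTYPING_ARRAY)
```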

patrick-kidger · Mar 20 '23

Ran into this as well with PyTorch. For me the solution, as described at https://docs.kidger.site/jaxtyping/api/array/#array, was to use `torch.Tensor` in place of `jaxtyping.Array`, like so:

```python
from torch import Tensor
from jaxtyping import Float32

def f(x: Float32[Tensor, "dim1 dim2"]) -> Float32[Tensor, "dim1 dim2"]:
    return x
```

MilesCranmer · Jun 20 '23

@MilesCranmer thanks. @patrick-kidger, has the JAX requirement been relaxed? I don't see it in pyproject.toml anymore. If so, I'll modify my code to use the syntax suggested by Miles.

danbider · Jun 20 '23

Sorry, missed this question. Yes, jaxtyping no longer depends on JAX. The name is now for historical reasons only! The syntax Miles is using is correct.
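One way to answer the pyproject.toml question for whichever version is actually installed, sketched with the standard library's `importlib.metadata` (Python 3.8+): `requires` returns the requirement strings a distribution declares. The helper names below are hypothetical, and filtering on `extra ==` is a heuristic for dropping optional extras.

```python
from importlib.metadata import PackageNotFoundError, requires
from typing import List, Optional


def declared_requirements(dist_name: str) -> Optional[List[str]]:
    """Requirement strings declared by an installed distribution.

    Returns None if the distribution is not installed.
    """
    try:
        # `requires` returns None when the distribution declares no dependencies.
        return requires(dist_name) or []
    except PackageNotFoundError:
        return None


def unconditional_requirements(dist_name: str) -> Optional[List[str]]:
    """Like declared_requirements, but drops entries tied to an optional extra."""
    reqs = declared_requirements(dist_name)
    if reqs is None:
        return None
    # Optional extras carry an environment marker such as `extra == "dev"`.
    return [r for r in reqs if "extra ==" not in r]


print("jaxtyping requires:", unconditional_requirements("jaxtyping"))
```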

patrick-kidger · Aug 21 '23

When authoring ML-runtime-agnostic tooling, such as a dataset, what is the correct array type to use? I cannot assume that any of torch, jax, or tensorflow is installed. I currently assume at least numpy and do the following, but it might not work for other use cases:

```python
from typing import Union, TYPE_CHECKING
from jaxtyping import Float, Bool
if TYPE_CHECKING:
    from torch import Tensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    # TODO: tensorflow
    Array = Union[Tensor, ndarray, JaxArray]
else:
    from numpy import ndarray as Array
```

pbsds · Oct 16 '23
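The pattern above works because `typing.TYPE_CHECKING` is `False` at runtime but treated as `True` by static checkers such as mypy: the checker sees the full `Union`, while the interpreter only executes the `else` branch and never imports torch or jax. A stdlib-only sketch of the same mechanism, with `Decimal` and `Fraction` standing in for the array types purely for illustration:

```python
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Only static type checkers evaluate this branch, so these imports
    # never run (and never need to be installed) at runtime.
    from decimal import Decimal
    from fractions import Fraction

    Number = Union[Decimal, Fraction]
else:
    # At runtime the alias is bound to a single concrete type.
    from fractions import Fraction as Number


def halve(x: Number) -> Number:
    return x / 2


print(TYPE_CHECKING, Number)
```

One likely consequence of this design: anything that inspects annotations at runtime (for example a runtime type checker used with jaxtyping) only sees `ndarray`, so a `torch.Tensor` argument would presumably be rejected at runtime even though the static checker accepts it.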

Probably something like this:

```python
from typing import Union, TYPE_CHECKING
if TYPE_CHECKING:
    from torch import Tensor as TorchTensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    from tensorflow import Tensor as TfTensor
    Array = Union[TorchTensor, ndarray, JaxArray, TfTensor]
else:
    arrays = []
    try:
        from torch import Tensor as TorchTensor
    except Exception:
        pass
    else:
        arrays.append(TorchTensor)
    try:
        from numpy import ndarray
    except Exception:
        pass
    else:
        arrays.append(ndarray)
    try:
        from jaxtyping import Array as JaxArray
    except Exception:
        pass
    else:
        arrays.append(JaxArray)
    try:
        from tensorflow import Tensor as TfTensor
    except Exception:
        pass
    else:
        arrays.append(TfTensor)
    if not arrays:
        # Union[()] raises TypeError, so fail with a clearer message instead.
        raise ImportError("none of torch, numpy, jax or tensorflow is installed")
    Array = Union[tuple(arrays)]
```

patrick-kidger · Oct 16 '23

Neat! I'd go for `except (ModuleNotFoundError, ImportError):` 😉

It doesn't exactly roll off the tongue, though. Any chance this could be added to jaxtyping?

pbsds · Oct 16 '23

Actually, catching the more general `Exception` is deliberate. Try-importing a module can fail with errors other than `ImportError` too; cf. https://github.com/google/jaxtyping/blob/7a84b27da9e57c425ce4e6333121c3cdf2e07302/jaxtyping/_array_types.py#L33-L39
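A small stdlib-only sketch of that distinction. `ModuleNotFoundError` is already a subclass of `ImportError`, so the suggested tuple behaves the same as `except ImportError`; but a package that is installed and then fails while initialising can raise something else entirely. The module name `broken_backend` and its `RuntimeError` are made up for the demonstration.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# ModuleNotFoundError subclasses ImportError, so `except ImportError`
# already covers both names in the suggested tuple.
assert issubclass(ModuleNotFoundError, ImportError)


def import_failure_kind(module_name: str) -> str:
    """Import `module_name` and report what kind of exception (if any) it raised."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        return "ImportError"
    except Exception as exc:  # raised by the module's own initialisation code
        return type(exc).__name__
    return "ok"


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical backend that is installed but fails while initialising.
    Path(tmp, "broken_backend.py").write_text(
        "raise RuntimeError('backend failed to initialise')\n"
    )
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        missing = import_failure_kind("definitely_missing_backend")
        broken = import_failure_kind("broken_backend")
    finally:
        sys.path.remove(tmp)

print(missing, broken)
```

A guard written as `except ImportError` would let the `RuntimeError` propagate and crash whatever module is doing the optional import, which is what `except Exception` avoids.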

As for adding the above to jaxtyping: jaxtyping essentially tries to be backend-agnostic. In particular, I don't think I'd want to hardcode that it looks specifically for torch+numpy+tensorflow+jax and nothing else. As such, I'm afraid something like this is out of scope for jaxtyping.

patrick-kidger · Oct 16 '23
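Since a built-in helper is out of scope for jaxtyping, one option is to keep it in user code and let the caller supply the list of backends, so no particular set is hardcoded. A minimal sketch, assuming the same `except Exception` policy discussed above; `optional_union` is a hypothetical name, not a jaxtyping API:

```python
import importlib
from typing import Any, Sequence, Tuple, Union


def optional_union(candidates: Sequence[Tuple[str, str]]) -> Any:
    """Return a Union of every `module.attribute` in `candidates` that imports.

    Any exception raised while importing (not only ImportError) causes that
    candidate to be skipped, as does a missing attribute.
    """
    found = []
    for module_name, attr_name in candidates:
        try:
            module = importlib.import_module(module_name)
            found.append(getattr(module, attr_name))
        except Exception:
            continue
    if not found:
        # Union[()] would raise an unhelpful TypeError.
        raise ImportError(f"none of {list(candidates)} could be imported")
    # A Union of a single type collapses to that type; duplicates are merged.
    return Union[tuple(found)]


if __name__ == "__main__":
    # Stdlib-only demo: the missing module is skipped silently.
    print(optional_union([
        ("decimal", "Decimal"),
        ("fractions", "Fraction"),
        ("no_such_module_xyz", "Anything"),
    ]))
```

With the backends from this thread, the call would be `optional_union([("numpy", "ndarray"), ("torch", "Tensor"), ("jaxtyping", "Array"), ("tensorflow", "Tensor")])`. This only builds the runtime alias: static type checkers cannot see through a union assembled at runtime, so the `TYPE_CHECKING` branch is still needed for them.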