Annotate a function which returns a specific module (e.g. `-> Literal[np]`)
## Problem
Currently it's not possible to annotate a function which returns a specific module, like:
```python
np = load_numpy()  # def load_numpy() -> ??:
x = np.array(123)  # << I want to have auto-complete & type checking here
```
- `-> types.ModuleType` does not work, as it is too generic (no auto-complete, no static type checking, ...); see the sketch below
- `Protocol` is not applicable in practice: maintaining a numpy protocol which has 500+ symbols is just not realistic
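A minimal sketch of what this looks like today with `types.ModuleType` (the `load_numpy` helper is the hypothetical loader from above): the annotation is accepted, but arbitrary attribute access on a `ModuleType` value resolves to `Any`, so there is no real auto-complete or checking.

```python
# Minimal sketch, assuming the hypothetical load_numpy() helper above.
import types

def load_numpy() -> types.ModuleType:
    import numpy as np
    return np

np = load_numpy()
x = np.array(123)  # accepted, but "array" is just Any to the checker
```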
## Use case
Here are some concrete use cases where this feature is needed (also collected from comments in this thread):
1. NumPy Enhancement Proposal 37 proposes a recipe to have code working with various numpy implementations (`numpy`, `jax.numpy`, `tensorflow.numpy`):

   ```python
   def duckarray_add_random(array):
       module = np.get_array_module(array)  # def get_array_module() -> Literal[np]
       noise = module.random.randn(*array.shape)  # << I want to have auto-complete & type checking here
       return array + noise
   ```

   I developed my version at https://github.com/google/etils/tree/main/etils/enp#code-that-works-with-nparray-jnparray-tftensor

2. Lazy imports are a common pattern to only import a module if needed, like: https://github.com/tensorflow/datasets/blob/76f8591def26afaca16340b06d057553582f6163/tensorflow_datasets/core/lazy_imports_lib.py#L40-L197 (a minimal sketch follows after this list)

   ```python
   beam = lazy_import.apache_beam
   beam.Pipeline()  # << No auto-completion
   ```
3. From another user comment: a similar issue to 1. was encountered at: https://github.com/data-apis/array-api/issues/267
4. From another user comment: I encountered a similar issue before, although not with the return type but rather with version-dependent imports, e.g. assigning either `ast` (Py >= 3.8) or `typed_ast.ast3` to a common variable.
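For the lazy-import use case (2.), here is a minimal sketch of such a helper (the `lazy_import` function below is illustrative, not the tensorflow_datasets implementation); today its best return annotation is the generic `types.ModuleType`, which is why the call site gets no completion:

```python
# Illustrative lazy-import helper; types.ModuleType is the most precise
# return annotation available today.
import importlib
import types

def lazy_import(name: str) -> types.ModuleType:
    return importlib.import_module(name)

beam = lazy_import("apache_beam")
beam.Pipeline()  # << No auto-completion: "Pipeline" resolves to Any
```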
## Proposal
I would like to annotate my function as:
```python
def load_numpy() -> Literal[np]:
```
Or:
```python
def load_numpy() -> np:
```
For the lazy-loading case, the `typing.TYPE_CHECKING` pattern could be used:
```python
if typing.TYPE_CHECKING:
    import numpy as np

def load_numpy() -> Literal[np]:
    import numpy as np
    return np
```
I encountered a similar issue before, although not with the return type but rather with version-dependent imports, e.g. assigning either `ast` (Py >= 3.8) or `typed_ast.ast3` to a common variable.
Not sure `Literal[np]` would be a good solution though. `types.ModuleType` isn't wrong, it's just not precise enough. Maybe it should be generic?
```python
if typing.TYPE_CHECKING:
    import numpy as np

def load_numpy() -> types.ModuleType[np]:
    ...
```
That's quite close to how pyright already handles it at the moment.
```python
import numpy as np
reveal_type(np)  # Type of "np" is "Module("numpy")"
```
If possible, it's best to avoid dynamically loading modules and stick to the normal import statements. That works best with static type checkers.
If you really need to describe a module as a type, one option is to use a protocol class to describe the interface provided by the module in question. PEP 544 specifically allows this. Pyright implements this functionality, but mypy hasn't yet implemented this. That's probably not what you're looking for in this case.
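A minimal sketch of that protocol option (the `ArrayModule` name and its two members are illustrative; whether a given checker accepts a given module against a given protocol depends on the exact members and signatures involved):

```python
# Illustrative protocol describing a tiny slice of a module's interface.
# A module object can satisfy such a Protocol structurally (PEP 544),
# but completion/checking only covers the members declared here.
from typing import Any, Callable, Protocol

class ArrayModule(Protocol):
    array: Callable[..., Any]
    sum: Callable[..., Any]

def load_numpy() -> ArrayModule:
    import numpy as np
    return np

xnp = load_numpy()
xnp.array(123)  # declared on the protocol: visible to the checker
xnp.random      # error: not declared on the protocol
```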
As you pointed out, another option is to use an `if TYPE_CHECKING` conditional, like this:
```python
if TYPE_CHECKING:
    import numpy as np
else:
    np = load_numpy()
```
I don't think `Literal[np]` would make sense here. And using a variable like `np` in a type expression would violate a bunch of other principles about type annotation expressions.
> If you really need to describe a module as a type, one option is to use a protocol class to describe the interface provided by the module in question.
This quickly becomes unfeasible for larger namespaces though. The main numpy namespace, for example, contains ~550 objects; the duplication necessary to define all the corresponding protocol members is massive and, at best, highly impractical.
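For reference, a quick way to see the scale of that namespace (the exact count varies with the numpy version):

```python
import numpy as np

# Number of names in the top-level numpy namespace (dunders included);
# on the order of several hundred, ~550 at the time of this discussion.
print(len(dir(np)))
```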
Secondly, as an additional data point: the same issue (no support for literal modules) was previously encountered in https://github.com/data-apis/array-api/issues/267.
> It's best to avoid dynamically loading modules and stick to the normal import statements.
It's not always possible, as pointed out in the use cases (e.g. NEP 37).
> another option is to use an `if TYPE_CHECKING` conditional, like this:
Your solution loads `np = load_numpy()` in the global scope, which doesn't work for the use cases explained in the top-level comment (lazy imports, NEP 37, ...).
Any update on this issue? Not being able to annotate modules prevents using auto-completion & type checking:
```python
def my_transformation(x, xnp: NpModule):  # xnp can be jax.numpy, torch, tf.numpy
    return xnp.sum(x)  # << No auto-completion nor type-checking here
```
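Today `NpModule` can only be approximated; a minimal sketch of such a stand-in (the alias below is hypothetical, not taken from an existing library):

```python
# Hypothetical stand-in: alias NpModule to the generic types.ModuleType.
# The annotation is accepted, but attribute access still resolves to Any.
import types

NpModule = types.ModuleType

def my_transformation(x, xnp: NpModule):  # xnp can be jax.numpy, torch, tf.numpy
    return xnp.sum(x)  # << still no auto-completion nor type-checking
```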
You can be creative with the `TYPE_CHECKING` guard, and certainly don't have to do it in the global scope.
Something like this might work as well (untested):
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import numpy as lazymodule
else:
    class lazymodule:
        np = load_numpy()

        @classmethod
        def __getattr__(cls, attr):
            return getattr(cls.np, attr)

reveal_type(lazymodule.array)
```
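As written, the sketch above mainly helps the type checker: at runtime, class-level attribute access never reaches a `__getattr__` defined on the class itself, so a working variant needs the hook on a metaclass. A minimal sketch, still assuming a `load_numpy()` helper:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import numpy as lazymodule
else:
    # Forward class-level attribute access to the loaded module at runtime.
    class _LazyModuleMeta(type):
        def __getattr__(cls, attr):
            return getattr(cls._module, attr)

    class lazymodule(metaclass=_LazyModuleMeta):
        _module = load_numpy()

reveal_type(lazymodule.array)  # seen as numpy's "array" by the checker
```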