Suggestion: Make `wp.array` class Generic
Hello! I've got a question: have you considered making `wp.array` a `Generic` type, rather than calling its constructor with arguments inside type annotations?
For example, from this:
```python
@wp.kernel
def apply_forces(grid: wp.uint64,
                 particle_x: wp.array(dtype=wp.vec3),
                 particle_v: wp.array(dtype=wp.vec3),
                 particle_f: wp.array(dtype=wp.vec3),
                 radius: float,
                 k_contact: float,
                 k_damp: float,
                 k_friction: float,
                 k_mu: float):
    ...
```
to this:
```python
@wp.kernel
def apply_forces(grid: wp.uint64,
                 particle_x: wp.array[wp.vec3],
                 particle_v: wp.array[wp.vec3],
                 particle_f: wp.array[wp.vec3],
                 radius: float,
                 k_contact: float,
                 k_damp: float,
                 k_friction: float,
                 k_mu: float):
    ...
```
This would have the following benefits:
- The annotations would be valid type expressions (i.e. no calls inside annotations), so static type checkers could be used on code that defines kernels.
- It would become possible to enable postponed evaluation of annotations (`from __future__ import annotations`, https://peps.python.org/pep-0563/) in user code, which doesn't seem to be supported at the moment (but correct me if I'm wrong).
I assume you're using something like `typing.get_type_hints` or the `__annotations__` dict directly in `wp.kernel` to extract the type annotations from the function, correct?
With a generic `wp.array` type, the dtype can still easily be recovered by calling `typing.get_args` on the annotation.
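A minimal sketch of the idea, using hypothetical stand-in classes (`Array`, `Float32`, `Vec3`) rather than warp's real types, and a hypothetical helper `kernel_arg_types` that is not how `wp.kernel` is actually implemented. A class deriving from `typing.Generic` becomes subscriptable; string annotations (what PEP 563 postponed evaluation stores in `__annotations__`) are resolved by `typing.get_type_hints`; and the dtype comes back out via `typing.get_origin` / `typing.get_args`.

```python
import typing
from typing import Generic, TypeVar


# Hypothetical stand-ins for warp types; NOT warp's actual classes.
class Float32:
    """Placeholder scalar dtype (plays the role of wp.float32)."""


class Vec3:
    """Placeholder vector dtype (plays the role of wp.vec3)."""


DType = TypeVar("DType")


class Array(Generic[DType]):
    """Hypothetical generic array type: Array[Vec3] is a valid type expression."""


# String annotations mimic what `from __future__ import annotations` (PEP 563) stores.
def apply_forces(particle_x: "Array[Vec3]",
                 masses: "Array[Float32]",
                 radius: "float") -> None:
    ...


def kernel_arg_types(func):
    """Hypothetical helper: map each argument name to (container, dtype)."""
    # get_type_hints evaluates string annotations in the function's module globals.
    hints = typing.get_type_hints(func)
    hints.pop("return", None)
    arg_types = {}
    for name, hint in hints.items():
        origin = typing.get_origin(hint)
        if origin is not None and issubclass(origin, Array):
            # Array[Vec3] -> origin Array, args (Vec3,)
            arg_types[name] = (origin, typing.get_args(hint)[0])
        else:
            # Plain scalar annotations such as float have no origin.
            arg_types[name] = (hint, None)
    return arg_types


print(apply_forces.__annotations__["particle_x"])  # the raw, unevaluated string
print(kernel_arg_types(apply_forces))
```

Because `Array[Vec3]` is an ordinary subscripted generic, nothing in the signature has to be called when the annotation is evaluated, which is what lets both `get_type_hints` and static type checkers accept it.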
Let me know what you think!
Ah, this is a great suggestion, thanks @lebrice! I will look into what it would mean for the code base.
Hello, this feature would indeed be great. Is there any news about it?
Here is a simple example of parsing annotations, as described by @lebrice:
```python
import inspect
import typing

import warp as wp


def my_kernel(arr: wp.array[wp.float32], scalar: float) -> None: ...


# eval_str=True (Python 3.10+) also resolves string annotations,
# e.g. those produced by `from __future__ import annotations`.
sig: inspect.Signature = inspect.signature(my_kernel, eval_str=True)
for name, param in sig.parameters.items():
    if isinstance(param.annotation, type):
        # handle `float`, `int`, etc.
        print(f"{name=}, {param.annotation=}")
    else:
        # handle `wp.array[wp.float32]`, etc.
        origin = typing.get_origin(param.annotation)
        # guard: get_origin returns None for non-generic annotations
        if origin is not None and issubclass(origin, wp.array):
            dtype = typing.get_args(param.annotation)[0]
            print(f"{name=}, {origin=}, {dtype=}")
```
It will output:
```
name='arr', origin=<class 'warp.types.array'>, dtype=<class 'warp.types.float32'>
name='scalar', param.annotation=<class 'float'>
```