[BUG] Error when using jax_callable with vmap
Bug Description
I'm trying to use Warp's JAX FFI feature, and I see that jax_callable supports vmap. However, when I use it under vmap, it fails with an error about array dimensions. Below is a minimal example:
```python
import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_callable

@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    # Scale each element of a by s.
    tid = wp.tid()
    output[tid] = a[tid] * s

def in_out_func(
    a: wp.array(dtype=float),  # input only
    c: wp.array(dtype=float),  # output only
):
    wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])

jax_func = jax_callable(in_out_func, vmap_method='broadcast_all', num_outputs=1)
f = jax.jit(jax_func)

# 10 x 10 input; vmap maps f over the leading axis (one row of 10 elements per call).
a = jnp.ones(100, dtype=jnp.float32).reshape((10, 10))
c = jax.vmap(f, in_axes=0, out_axes=0)(a)
c = f(a)
print(c)
```
The error I'm getting is:
```
RuntimeError: Error launching kernel 'scale_kernel', argument 'a' expects an array with 1 dimension(s) but the passed array has 2 dimension(s).
E0721 21:11:26.423976 3931630 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: UNKNOWN: FFI callback error: RuntimeError: Error launching kernel 'scale_kernel', argument 'a' expects an array with 1 dimension(s) but the passed array has 2 dimension(s).
```
System Information
I'm using Warp 1.7.0.
vmap doesn't seem to work with jax_kernel either. I took this example from the documentation and tried to vmap it over a 2-D array:
```python
import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def add_kernel(a: wp.array(dtype=int),
               b: wp.array(dtype=int),
               output: wp.array(dtype=int)):
    # Elementwise sum of two 1-D arrays.
    tid = wp.tid()
    output[tid] = a[tid] + b[tid]

jax_add = jax_kernel(add_kernel)

@jax.jit
def f(a, b):
    return jax_add(a, b)

n = 10
a = jnp.arange(n*n, dtype=jnp.int32).reshape(n, n)
b = jnp.ones(n*n, dtype=jnp.int32).reshape(n, n)
# vmap over the leading axis: each call should add one row of a to one row of b.
print(jax.vmap(f)(a, b))
```
This runs without any error, but the output is wrong:
```
[Array([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)]
```
The expected output is:
```
[Array([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[ 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
[ 21, 22, 23, 24, 25, 26, 27, 28, 29, 30],
[ 31, 32, 33, 34, 35, 36, 37, 38, 39, 40],
[ 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
[ 51, 52, 53, 54, 55, 56, 57, 58, 59, 60],
[ 61, 62, 63, 64, 65, 66, 67, 68, 69, 70],
[ 71, 72, 73, 74, 75, 76, 77, 78, 79, 80],
[ 81, 82, 83, 84, 85, 86, 87, 88, 89, 90],
[ 91, 92, 93, 94, 95, 96, 97, 98, 99, 100]], dtype=int32)]
```
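For comparison, the expected values can be reproduced with plain JAX, independently of Warp. In this sketch, add_rows is a hypothetical pure-JAX stand-in for add_kernel (it is not part of Warp); vmap maps it over the leading axis so that each call sees one row of a and one row of b, which is the per-row behavior the Warp version should match.

```python
import jax
import jax.numpy as jnp

def add_rows(a, b):
    # Hypothetical pure-JAX stand-in for add_kernel: elementwise sum of two 1-D arrays.
    return a + b

n = 10
a = jnp.arange(n * n, dtype=jnp.int32).reshape(n, n)
b = jnp.ones(n * n, dtype=jnp.int32).reshape(n, n)

# vmap calls add_rows once per row: row i of the result is a[i] + b[i].
expected = jax.vmap(add_rows)(a, b)
print(expected)
```

Because add_rows is elementwise, the batched result is simply a + b, so row i holds the values 10*i + 1 through 10*i + 10.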
Just wanted to kindly follow up on this. Do you have an estimate of when the bug might be fixed? I really appreciate your help! @nvlukasz @shi-eric