
[BUG] Error when using jax_callable with vmap

Open · HaoliangWang opened this issue 7 months ago · 2 comments

Bug Description

I'm trying to use Warp's JAX FFI feature, and I see that jax_callable supports vmap. However, when I use it under vmap, it fails with a dimension mismatch. Below is a minimal example:

import jax
import jax.numpy as jnp

import warp as wp
from warp.jax_experimental.ffi import jax_callable


@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] * s


def in_out_func(
    a: wp.array(dtype=float),  # input only
    c: wp.array(dtype=float),  # output only
):
    wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])


jax_func = jax_callable(in_out_func, vmap_method='broadcast_all', num_outputs=1)

f = jax.jit(jax_func)

a = jnp.ones(100, dtype=jnp.float32).reshape((10, 10))

c = jax.vmap(f, in_axes=0, out_axes=0)(a)
c = f(a)
print(c)

The error I'm getting is:

RuntimeError: Error launching kernel 'scale_kernel', argument 'a' expects an array with 1 dimension(s) but the passed array has 2 dimension(s).

E0721 21:11:26.423976 3931630 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: UNKNOWN: FFI callback error: RuntimeError: Error launching kernel 'scale_kernel', argument 'a' expects an array with 1 dimension(s) but the passed array has 2 dimension(s).
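Here is a model of the error. With vmap_method='broadcast_all', JAX calls the callback once with the whole batched input (here of shape (10, 10)) instead of once per row. As a result, in_out_func passes a 2-D array to a kernel declared with a 1-D wp.array. The pure-Python sketch below only illustrates that reading. It is not JAX's or Warp's implementation, and scale_1d, scale_any_rank and the vmap_* helpers are hypothetical names. The sketch contrasts broadcast_all with per-element ("sequential") batching, which JAX documents as another vmap_method for jax.pure_callback and jax.ffi.ffi_call. It also shows that a callback that flattens the batch first handles both cases. Whether Warp's jax_callable accepts vmap_method='sequential' is an assumption that has not been checked here.

```python
# Pure-Python model of two vmap strategies for a callback.
# Illustration of the semantics only, NOT JAX's or Warp's implementation.

def scale_1d(a, s=2.0):
    """Hypothetical stand-in for in_out_func: like scale_kernel, 1-D input only."""
    if any(isinstance(x, list) for x in a):
        raise RuntimeError(
            "argument 'a' expects an array with 1 dimension(s) "
            "but the passed array has 2 dimension(s)"
        )
    return [x * s for x in a]


def vmap_sequential(fn, batch):
    """Call fn once per batch element (mapped axis 0), then stack the results."""
    return [fn(row) for row in batch]


def vmap_broadcast_all(fn, batch):
    """Call fn once with the whole batch; fn must handle the extra leading axis."""
    return fn(batch)


def scale_any_rank(a, s=2.0):
    """Hypothetical shape-agnostic callback: flatten, scale, restore the shape."""
    if a and isinstance(a[0], list):
        width = len(a[0])
        flat = [x for row in a for x in row]
        out = scale_1d(flat, s)
        return [out[i:i + width] for i in range(0, len(out), width)]
    return scale_1d(a, s)


batch = [[1.0] * 10 for _ in range(10)]  # same shape as `a` in the report: (10, 10)

print(vmap_sequential(scale_1d, batch)[0][:3])  # [2.0, 2.0, 2.0]
try:
    vmap_broadcast_all(scale_1d, batch)
except RuntimeError as err:
    print("broadcast_all failed:", err)
print(vmap_broadcast_all(scale_any_rank, batch)[0][:3])  # [2.0, 2.0, 2.0]
```

In Warp terms, the analogous change (also untested here) would be for in_out_func to launch over flattened views of its input and output arrays, for example via wp.array.flatten(), rather than assume the input is 1-D.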

System Information

I'm using Warp 1.7.0.

Posted by HaoliangWang on Jul 22 '25 01:07

vmap doesn't seem to work with jax_kernel either.

I took this example from the documentation and tried to vmap it over a 2D array:

import jax
import jax.numpy as jnp

import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def add_kernel(a: wp.array(dtype=int),
               b: wp.array(dtype=int),
               output: wp.array(dtype=int)):
    tid = wp.tid()
    output[tid] = a[tid] + b[tid]

jax_add = jax_kernel(add_kernel)

@jax.jit
def f(a, b):
    return jax_add(a, b)

n = 10
a = jnp.arange(n*n, dtype=jnp.int32).reshape(n, n)
b = jnp.ones(n*n, dtype=jnp.int32).reshape(n, n)
print(jax.vmap(f)(a, b))

This runs without an error, but the output is wrong (only the first row is correct):

[Array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1]], dtype=int32)]

The expected output is:

[Array([[  1,   2,   3,   4,   5,   6,   7,   8,   9,  10],
       [ 11,  12,  13,  14,  15,  16,  17,  18,  19,  20],
       [ 21,  22,  23,  24,  25,  26,  27,  28,  29,  30],
       [ 31,  32,  33,  34,  35,  36,  37,  38,  39,  40],
       [ 41,  42,  43,  44,  45,  46,  47,  48,  49,  50],
       [ 51,  52,  53,  54,  55,  56,  57,  58,  59,  60],
       [ 61,  62,  63,  64,  65,  66,  67,  68,  69,  70],
       [ 71,  72,  73,  74,  75,  76,  77,  78,  79,  80],
       [ 81,  82,  83,  84,  85,  86,  87,  88,  89,  90],
       [ 91,  92,  93,  94,  95,  96,  97,  98,  99, 100]], dtype=int32)]
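The expected values follow from what vmap means here. Row i of the result is add_kernel applied to a[i] and b[i]. Because a is arange(100) reshaped to (10, 10) and b is all ones, element (i, j) equals 10*i + j + 1. The check below reproduces that arithmetic in plain Python, independent of JAX and Warp. add_row is a hypothetical helper that mirrors add_kernel for one 1-D row.

```python
n = 10
# a = arange(n*n).reshape(n, n) and b = ones((n, n)), as in the report
a = [[i * n + j for j in range(n)] for i in range(n)]
b = [[1] * n for _ in range(n)]


def add_row(a_row, b_row):
    # Mirrors add_kernel on one 1-D row: output[tid] = a[tid] + b[tid]
    return [x + y for x, y in zip(a_row, b_row)]


# vmap over axis 0: apply the 1-D kernel to each row independently
expected = [add_row(a[i], b[i]) for i in range(n)]

print(expected[0])   # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(expected[-1])  # [91, 92, 93, 94, 95, 96, 97, 98, 99, 100]
```

The reported output matches this expected result in row 0 only. Every later row holds 1 instead of 10*i + j + 1.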

Posted by HaoliangWang on Jul 28 '25 11:07

Just wanted to kindly follow up on this—do you have an estimate of when the bug might be fixed? Really appreciate your help! @nvlukasz @shi-eric

Posted by HaoliangWang on Aug 04 '25 22:08