
Using quax with jax.jvp

Open: nardi opened this issue 10 months ago • 1 comment

This is more of a question, but I didn't see a discussions tab. Feel free to move/close if not appropriate :)

I'm trying to pass some quax values through a JVP computation. In principle, it seems to work fine, but I'm not able to find a way to take the JVP of a subset of arguments only. Here is a small example.

Suppose we have a simple array wrapper class:

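from functools import partial

import jax
import jax.numpy as jnp
import quax
from jax.tree_util import Partial

class ArrayWrapper(quax.ArrayValue):
    array: jax.Array

    def aval(self):
        # Report the abstract value (shape/dtype) of the wrapped array.
        return self.array.aval

    def materialise(self):
        # No dense fallback is provided for primitives without a registered rule.
        raise NotImplementedError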

And I've defined dot_general on it:

@quax.register(jax.lax.dot_general_p)
def dot_general(a: ArrayWrapper, b: ArrayWrapper, **params):
    return ArrayWrapper(jax.lax.dot_general_p.bind(a.array, b.array, **params))

@quax.register(jax.lax.dot_general_p)
def dot_general(a, b: ArrayWrapper, **params):
    return ArrayWrapper(jax.lax.dot_general_p.bind(a, b.array, **params))

@quax.register(jax.lax.dot_general_p)
def dot_general(a: ArrayWrapper, b, **params):
    return ArrayWrapper(jax.lax.dot_general_p.bind(a.array, b, **params))

Now, we can evaluate jnp.dot:

x = ArrayWrapper(1 + jnp.arange(3, dtype=float))
quax_dot = quax.quaxify(jnp.dot)

quax_dot(x, x)
# Output: ArrayWrapper(14.0)

Then, I would like to take the JVP of jnp.dot. I've found this is possible without any further changes by wrapping the jax.jvp call in quaxify:

quax.quaxify(lambda p, t: jax.jvp(jnp.dot, p, t))((x, x), (x, x))
# Output: (ArrayWrapper(14.0), ArrayWrapper(28.0))

Now, suppose I only want the JVP with respect to the second argument. Normally, this is possible by a simple partial application:

(lambda p, t: jax.jvp(partial(jnp.dot, x.array), p, t))((x.array,), (x.array,))
# Output: (14.0, 14.0)

However, I can't seem to find a way to accomplish this while preserving my ArrayWrapper types. The naive partial application:

quax.quaxify(lambda p, t: jax.jvp(partial(jnp.dot, x), p, t))((x,), (x,))

doesn't work of course, since the partially applied x does not pass through quaxify. However, using a nested quaxify also doesn't work correctly:

quax.quaxify(lambda p, t: jax.jvp(partial(quax.quaxify(jnp.dot), x), p, t))((x,), (x,))
# Output: (ArrayWrapper(ArrayWrapper(14.0)), ArrayWrapper(ArrayWrapper(14.0)))

I've tried a number of different ways of nesting quaxify and using filter_spec, but I can't seem to find a way to get this to work as expected. To be precise, the result I would expect is (ArrayWrapper(14.0), ArrayWrapper(14.0)). Of course, to obtain this I could simply set the tangents of the other arguments to zero:

quax.quaxify(lambda p, t: jax.jvp(jnp.dot, p, t))((x, x), (ArrayWrapper(jnp.zeros_like(x.array)), x))
# Output: (ArrayWrapper(14.0), ArrayWrapper(14.0))

But this seems like more of a workaround, and is not possible if some of the arguments are not differentiable.
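
As a sanity check on the numbers above, here is a minimal plain-JAX sketch (raw arrays only, no quax; the variable names are just for illustration). For dot(a, b) the tangent is da·b + a·db, so with x = [1, 2, 3] passing x as both tangents gives 14 + 14 = 28, while a zero tangent for the first argument leaves only a·db = 14, the same value as the partial-application version.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = 1 + jnp.arange(3, dtype=float)  # [1., 2., 3.]

# Tangents for both arguments: da.b + a.db
primal_out, tangent_both = jax.jvp(jnp.dot, (x, x), (x, x))

# Zero tangent for the first argument: only a.db remains
_, tangent_zeroed = jax.jvp(jnp.dot, (x, x), (jnp.zeros_like(x), x))

# Partial application: differentiate with respect to the second argument only
_, tangent_partial = jax.jvp(partial(jnp.dot, x), (x,), (x,))

print(primal_out, tangent_both, tangent_zeroed, tangent_partial)
```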

Am I missing some proper way to do this or am I asking for something that for a good reason is not possible?

nardi, Jun 09 '25 17:06

Of course, right after asking I thought about it a bit more and answered my own question :)

In the naive case:

quax.quaxify(lambda p, t: jax.jvp(partial(jnp.dot, x), p, t))((x,), (x,))

the problem is that x is closed over rather than passed in as an argument, so partial(jnp.dot, x) effectively happens outside the quax boundary. So it is equivalent to:

x_dot = partial(jnp.dot, x)

quax.quaxify(lambda p, t: jax.jvp(x_dot, p, t))((x,), (x,))

So the partially applied value x is basically global state!

I thought of two ways to fix this: either I can partially apply x inside the quaxified function:

quax.quaxify(lambda x, p, t: jax.jvp(partial(jnp.dot, x), p, t))(x, (x,), (x,))

or I can pass the partially applied function through (for this I use JAX's Partial):

x_dot = Partial(jnp.dot, x)

quax.quaxify(lambda x_dot, p, t: jax.jvp(x_dot, p, t))(x_dot, (x,), (x,))
# or just
quax.quaxify(jax.jvp)(x_dot, (x,), (x,))
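
To illustrate the difference between the two kinds of partial, here is a small plain-JAX sketch (no quax involved; the variable names are illustrative). functools.partial is not a registered pytree, so its bound x stays hidden inside a single opaque leaf, whereas jax.tree_util.Partial is a pytree whose bound arguments show up as leaves. I'm assuming this is what lets quaxify find the wrapped value inside x_dot; I haven't checked that against quax's source.

```python
import functools

import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_leaves

x = 1 + jnp.arange(3, dtype=float)  # [1., 2., 3.]

# functools.partial: the whole object is one leaf, the bound x is not exposed.
opaque = functools.partial(jnp.dot, x)
opaque_leaves = tree_leaves(opaque)

# jax.tree_util.Partial: the bound x is itself a leaf of the pytree.
visible = Partial(jnp.dot, x)
visible_leaves = tree_leaves(visible)

# Because Partial is a pytree, it can be passed as an argument to a jitted function.
result = jax.jit(lambda f, v: f(v))(visible, x)
```

So anything that processes its arguments as pytrees can see (and replace) the x bound inside a Partial, but not the x bound inside a functools.partial.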

So in the end the simplest thing to do (just quaxify jax.jvp as is) was also the right thing to do :) I guess I got distracted because normally in JAX you can't pass callables as arguments to a transformed function (such as one wrapped in jax.jit) unless they are wrapped in Partial, so it is usually simpler to refer to the function globally, since it will be traced out anyway. But this is a situation in which it does matter. Maybe something with partial would be a nice example for the docs though, since this is quite a common pattern when writing JAX code.

nardi, Jun 09 '25 18:06