Vmapping over callables / pytest cache_clear() error

Open thibmonsel opened this issue 3 years ago • 3 comments

Hi @patrick-kidger,

vmap Callables

I'd like to use a vmap'd batch of callables inside a dummy_diffeq that is itself vmap'd.

A minimal working example (MWE) would be something like:


import equinox as eqx
import jax.numpy as jnp

def vector_field(t, y, args):
    return -y 

@eqx.filter_vmap(args=(0,))
def history_fn(a):
    return lambda t : jnp.array([a * t, a * t])

vmap_history_fn = history_fn(jnp.arange(10))

def dummy_diffeq(t, y0):
    return vector_field(t, y0(t), None)

@eqx.filter_vmap(args=(None, 0))
def dummy_call(t, y0):
    sol = dummy_diffeq(t, y0)
    return sol

aux = dummy_call(1.0, vmap_history_fn)

The resulting error is:

 aux = dummy_call(ts, vmap_history_fn)
  File "/home/monsel/miniconda3/envs/dde/lib/python3.9/site-packages/equinox/vmap_pmap.py", line 187, in __call__
    vmapd, static = jax.vmap(
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
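For reference, the message itself is generic: jax.vmap raises this ValueError whenever an argument it was asked to map along axis 0 is a scalar (shape ()), because there is no axis to map over. A minimal sketch reproducing it in isolation with plain jax.vmap; double is an illustrative function, unrelated to diffrax or Equinox:

```python
import jax
import jax.numpy as jnp


def double(x):
    return 2 * x


# Map over axis 0 of the single argument.
batched_double = jax.vmap(double, in_axes=0)


def try_batched(x):
    """Return the batched result, or the ValueError message if x has no axis 0."""
    try:
        return batched_double(x)
    except ValueError as err:
        return str(err)


# A rank-1 input has an axis 0, so vmap maps over its 3 elements.
ok = try_batched(jnp.arange(3.0))

# A rank-0 (scalar) input has no axis 0, so vmap raises a ValueError instead.
msg = try_batched(jnp.asarray(1.0))
print(msg)
```

So the question to ask of the traceback is which leaf of the mapped argument ended up being a scalar.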

Unrelated pytest error

When I try to run the tests with pytest, I get errors:

test/conftest.py:37: TypeError
===================================================================================================== short test summary info ======================================================================================================
ERROR test/test_adaptive_stepsize_controller.py::test_step_ts - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adaptive_stepsize_controller.py::test_jump_ts - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_no_adjoint - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_backsolve - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_adjoint_seminorm - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_closure_errors - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_closure_fixed - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_implicit - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_bounded_while_loop.py::test_functional_no_vmap_no_inplace - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_bounded_while_loop.py::test_functional_no_vmap_inplace - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_bounded_while_loop.py::test_functional_vmap_no_inplace - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:

It seems that every unit test goes through the autouse clear_caches() fixture, which is where the error comes from:

    @pytest.fixture(autouse=True)
    def clear_caches():
        process = psutil.Process()
        if process.memory_info().vms > 4 * 2**30:  # >4GB memory usage
            for module_name, module in sys.modules.items():
                if module_name.startswith("jax"):
                    for obj_name in dir(module):
                        obj = getattr(module, obj_name)
                        if hasattr(obj, "cache_clear"):
>                           obj.cache_clear()
E                           TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
E                               1. (self: jaxlib.xla_extension.WeakrefLRUCache) -> None
E                           
E                           Invoked with:

test/conftest.py:37: TypeError

If I only run test_interpolation.py, for example, I get:

pytest  test_interpolation.py 
======================================================================================================= test session starts ========================================================================================================
platform linux -- Python 3.9.13, pytest-7.2.0, pluggy-1.0.0
rootdir: /home/monsel/Desktop/dev/diffrax
plugins: jaxtyping-0.2.7, typeguard-2.13.3
collected 1 item                                                                                                                                                                                                                   

test_interpolation.py E                                                                                                                                                                                                      [100%]

============================================================================================================== ERRORS ==============================================================================================================
________________________________________________________________________________________________ ERROR at setup of test_derivative _________________________________________________________________________________________________

    @pytest.fixture(autouse=True)
    def clear_caches():
        process = psutil.Process()
        if process.memory_info().vms > 4 * 2**30:  # >4GB memory usage
            for module_name, module in sys.modules.items():
                if module_name.startswith("jax"):
                    for obj_name in dir(module):
                        obj = getattr(module, obj_name)
                        if hasattr(obj, "cache_clear"):
>                           obj.cache_clear()
E                           TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
E                               1. (self: jaxlib.xla_extension.WeakrefLRUCache) -> None
E                           
E                           Invoked with:

conftest.py:37: TypeError
===================================================================================================== short test summary info ======================================================================================================
ERROR test_interpolation.py::test_derivative - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
========================================================================================================= 1 error in 0.18s =========================================================================================================
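A plausible reading of this traceback, not confirmed anywhere in this thread: the "Invoked with:" line is empty, so cache_clear() was called with no arguments at all. That is what happens when obj is the class WeakrefLRUCache itself rather than an instance, since hasattr() also finds the method on the class and calling it through the class supplies no self. Assuming that is the cause, a more defensive version of the loop could skip classes, tolerate unexpected signatures, and iterate over a snapshot of sys.modules. clear_caches_by_prefix is a hypothetical helper, not part of the diffrax test suite:

```python
import sys


def clear_caches_by_prefix(prefix="jax"):
    """Call cache_clear() on module-level objects of modules named prefix*.

    Hypothetical helper: a defensive variant of the loop in the conftest fixture.
    """
    # Iterate over a snapshot: getattr can trigger lazy imports that mutate sys.modules.
    for module_name, module in list(sys.modules.items()):
        if module is None or not module_name.startswith(prefix):
            continue
        for obj_name in dir(module):
            obj = getattr(module, obj_name, None)
            # Skip classes: SomeClass.cache_clear() is an unbound call with no self,
            # which is the assumed cause of the TypeError above.
            if isinstance(obj, type):
                continue
            cache_clear = getattr(obj, "cache_clear", None)
            if callable(cache_clear):
                try:
                    cache_clear()
                except TypeError:
                    # Some other signature we don't know how to call; leave it alone.
                    pass
```

In the fixture, the loop under the memory check would then reduce to a single clear_caches_by_prefix("jax") call.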

thibmonsel, Jan 12 '23 12:01

Sorry, I'm afraid what you're describing isn't clear. Are you describing two different errors here?

For the first one -- what's a MWE, and what's the error message you get?

For the second one -- that looks like some potential issue with the hacky way we clear memory in our tests. It's possible that a JAX update has broken this. What type was it actually invoked with?
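One way to answer the type question locally is to list which module-level objects in the already-imported jax* modules expose cache_clear, and whether each one is a class or an instance. A sketch, with find_cache_clear_owners as a hypothetical helper name:

```python
import sys


def find_cache_clear_owners(prefix="jax"):
    """Return (module_name, obj_name, is_class) for module-level objects with cache_clear."""
    found = []
    # Snapshot sys.modules so lazy imports triggered by getattr can't break iteration.
    for module_name, module in list(sys.modules.items()):
        if module is None or not module_name.startswith(prefix):
            continue
        for obj_name in dir(module):
            obj = getattr(module, obj_name, None)
            if hasattr(obj, "cache_clear"):
                # is_class=True means obj.cache_clear() would be called without a self.
                found.append((module_name, obj_name, isinstance(obj, type)))
    return found
```

Run it after import jax in the same environment as the failing tests; any entry with is_class=True is a candidate for the failing call.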

patrick-kidger, Jan 12 '23 13:01

Sorry about that, lmk if this is clearer now.

For the type, it says:

TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
E                               1. (self: jaxlib.xla_extension.WeakrefLRUCache) -> None
E                           
E                           Invoked with:

thibmonsel, Jan 12 '23 13:01

Regarding the vmap issue: this is expected. vmap_history_fn is a function, so the only way to pass it into a vmap'd function is as something that's broadcast, i.e. with None as its entry in args. (Here you have 0 instead.)
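A common way to get the intended effect is to keep a single function and batch only its array parameters. A minimal sketch under that assumption, using plain jax.vmap; history and dummy_diffeq mirror the MWE's names but are illustrative, with history now taking the scalar a as an explicit argument:

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    return -y


def history(t, a):
    # One history function; the per-batch variation lives in the array argument a.
    return jnp.array([a * t, a * t])


def dummy_diffeq(t, a):
    return vector_field(t, history(t, a), None)


# t is broadcast (in_axes=None); a is mapped along axis 0.
batched = jax.vmap(dummy_diffeq, in_axes=(None, 0))
out = batched(1.0, jnp.arange(10.0))  # shape (10, 2)
```

Each row of out corresponds to one value of a, which is what the vmap'd array of callables was meant to express.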

Also, note that

@eqx.filter_vmap(args=(0,))
def history_fn(a):
    return lambda t : jnp.array([a * t, a * t])

is not allowed. Here a is captured via closure and then "smuggled out" of the vmap region.
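A pattern that avoids the closure is to make the batch of callables a pytree whose array leaves carry the batch axis, so vmap sees those arrays directly. In Equinox, an eqx.Module that defines __call__ plays this role; to keep the sketch dependency-free it uses a typing.NamedTuple instead, which JAX also treats as a pytree. History is a hypothetical name:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class History(NamedTuple):
    a: jax.Array  # a scalar per history function once vmap has sliced the batch

    def __call__(self, t):
        return jnp.array([self.a * t, self.a * t])


def vector_field(t, y, args):
    return -y


def dummy_diffeq(t, y0):
    return vector_field(t, y0(t), None)


# A "batch of 10 callables": one History whose leaf a has shape (10,).
histories = History(a=jnp.arange(10.0))

# Broadcast t; map every leaf of histories along axis 0.
out = jax.vmap(dummy_diffeq, in_axes=(None, 0))(1.0, histories)
```

Because a enters and leaves vmap as an ordinary pytree leaf, nothing traced escapes the mapped region.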


Regarding the cache_clear error: I'm not able to reproduce this. Using JAX 0.4.2 and jaxlib 0.4.1, the tests run as normal.

patrick-kidger, Jan 19 '23 21:01