Vmapping over callables / pytest cache_clear() error
Hi @patrick-kidger,
vmap Callables
I'd like to use a vmap'd batch of callables inside a `dummy_diffeq` that is itself vmap'd.
A MWE would be something like:
```python
import equinox as eqx
import jax.numpy as jnp

def vector_field(t, y, args):
    return -y

@eqx.filter_vmap(args=(0,))
def history_fn(a):
    return lambda t: jnp.array([a * t, a * t])

vmap_history_fn = history_fn(jnp.arange(10))

def dummy_diffeq(t, y0):
    return vector_field(t, y0(t), None)

@eqx.filter_vmap(args=(None, 0))
def dummy_call(t, y0):
    sol = dummy_diffeq(t, y0)
    return sol

aux = dummy_call(1.0, vmap_history_fn)
```
The error raised is:
```
    aux = dummy_call(ts, vmap_history_fn)
  File "/home/monsel/miniconda3/envs/dde/lib/python3.9/site-packages/equinox/vmap_pmap.py", line 187, in __call__
    vmapd, static = jax.vmap(
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
```
Unrelated pytest error
When I try to run the tests with pytest, I get errors:
```
test/conftest.py:37: TypeError
===================================================================================================== short test summary info ======================================================================================================
ERROR test/test_adaptive_stepsize_controller.py::test_step_ts - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adaptive_stepsize_controller.py::test_jump_ts - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_no_adjoint - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_backsolve - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_adjoint_seminorm - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_closure_errors - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_closure_fixed - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_adjoint.py::test_implicit - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_bounded_while_loop.py::test_functional_no_vmap_no_inplace - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_bounded_while_loop.py::test_functional_no_vmap_inplace - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
ERROR test/test_bounded_while_loop.py::test_functional_vmap_no_inplace - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
```
It seems that every test goes through the autouse `clear_caches()` fixture in `conftest.py`, which is where the error comes from:
```
    @pytest.fixture(autouse=True)
    def clear_caches():
        process = psutil.Process()
        if process.memory_info().vms > 4 * 2**30:  # >4GB memory usage
            for module_name, module in sys.modules.items():
                if module_name.startswith("jax"):
                    for obj_name in dir(module):
                        obj = getattr(module, obj_name)
                        if hasattr(obj, "cache_clear"):
>                           obj.cache_clear()
E                           TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
E                               1. (self: jaxlib.xla_extension.WeakrefLRUCache) -> None
E
E                           Invoked with:

test/conftest.py:37: TypeError
```
If I only run test_interpolation.py, for example, I get:
```
pytest test_interpolation.py
======================================================================================================= test session starts ========================================================================================================
platform linux -- Python 3.9.13, pytest-7.2.0, pluggy-1.0.0
rootdir: /home/monsel/Desktop/dev/diffrax
plugins: jaxtyping-0.2.7, typeguard-2.13.3
collected 1 item

test_interpolation.py E [100%]

============================================================================================================== ERRORS ==============================================================================================================
________________________________________________________________________________________________ ERROR at setup of test_derivative _________________________________________________________________________________________________

    @pytest.fixture(autouse=True)
    def clear_caches():
        process = psutil.Process()
        if process.memory_info().vms > 4 * 2**30:  # >4GB memory usage
            for module_name, module in sys.modules.items():
                if module_name.startswith("jax"):
                    for obj_name in dir(module):
                        obj = getattr(module, obj_name)
                        if hasattr(obj, "cache_clear"):
>                           obj.cache_clear()
E                           TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
E                               1. (self: jaxlib.xla_extension.WeakrefLRUCache) -> None
E
E                           Invoked with:

conftest.py:37: TypeError
===================================================================================================== short test summary info ======================================================================================================
ERROR test_interpolation.py::test_derivative - TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
========================================================================================================= 1 error in 0.18s =========================================================================================================
```
Sorry, I'm afraid what you're describing isn't clear. Are you describing two different errors here?
For the first one -- what's a MWE, and what's the error message you get?
For the second one -- that looks like some potential issue with the hacky way we clear memory in our tests. It's possible that a JAX update has broken this. What type was it actually invoked with?
Sorry about that, lmk if this is clearer now.
For the type, it says:
```
TypeError: cache_clear(): incompatible function arguments. The following argument types are supported:
E                               1. (self: jaxlib.xla_extension.WeakrefLRUCache) -> None
E
E                           Invoked with:
```
Regarding the vmap issue: this is expected. `vmap_history_fn` is a function, so the only way to pass it into a vmap'd function is as something that's broadcast (an axis spec of `None`). Here you have `0` instead.
Also, note that
```python
@eqx.filter_vmap(args=(0,))
def history_fn(a):
    return lambda t: jnp.array([a * t, a * t])
```
is not allowed: `a` is captured via closure and then "smuggled out" of the vmap region.
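A minimal sketch of a pattern that avoids the smuggling: map over the array of parameters itself and construct the history function *inside* the mapped function, so the closure over `a` is created and consumed within the same vmap trace. `make_history_fn` and `solve_one` are illustrative names, and plain `jax.vmap` stands in for `eqx.filter_vmap`.

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    return -y


def make_history_fn(a):
    # Builds the callable from a (possibly traced) scalar; it is only used locally.
    return lambda t: jnp.array([a * t, a * t])


def dummy_diffeq(t, y0):
    return vector_field(t, y0(t), None)


def solve_one(t, a):
    # The closure over `a` never leaves this function, so nothing escapes the vmap.
    y0 = make_history_fn(a)
    return dummy_diffeq(t, y0)


# Broadcast t, map over the parameters; only arrays cross the vmap boundary.
solve_batch = jax.vmap(solve_one, in_axes=(None, 0))
out = solve_batch(1.0, jnp.arange(10.0))
print(out.shape)  # (10, 2)
```

Only arrays enter and leave `solve_batch`; the function objects exist purely inside the trace.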
Regarding the cache_clear error: I'm not able to reproduce this. Using JAX 0.4.2 and jaxlib 0.4.1, the tests run as normal.