chex
Add support for value assertions in jitted functions.
In my experience, when chex reports that max traces were exceeded, it's usually because I passed arguments with different shapes or data types to the function. Is it possible for chex...
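To illustrate why differing shapes or data types count as extra traces, here is a toy model in plain Python. It is only a sketch and not how chex or JAX implement tracing: `FakeArray` and `max_traces` are hypothetical names invented for this example, and the "trace" is simulated by remembering each distinct (shape, dtype) signature.

```python
import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical stand-in for an array: only shape and dtype matter here."""
    shape: tuple
    dtype: str


def max_traces(n):
    """Toy decorator (hypothetical): allow at most n distinct signatures."""
    def decorator(fn):
        seen = set()  # one entry per simulated trace

        @functools.wraps(fn)
        def wrapper(*args):
            # The signature is the tuple of (shape, dtype) pairs of all arguments.
            key = tuple((a.shape, a.dtype) for a in args)
            if key not in seen:
                if len(seen) >= n:
                    raise AssertionError(
                        f"{fn.__name__} traced more than {n} time(s); "
                        f"new signature: {key}"
                    )
                seen.add(key)
            return fn(*args)

        wrapper.seen = seen  # exposed so the signatures can be inspected
        return wrapper
    return decorator


@max_traces(n=1)
def describe(x):
    return f"{x.dtype}{list(x.shape)}"


describe(FakeArray((3,), "float32"))  # first signature: one simulated trace
describe(FakeArray((3,), "float32"))  # same signature: no new trace

try:
    describe(FakeArray((6, 7), "float32"))  # new shape: would need a second trace
except AssertionError as err:
    message = str(err)
```

In this sketch the error message includes the offending signature, which is the kind of information that makes a shape or dtype mismatch easy to spot.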
Hi, thanks for making this awesome library! Is it possible to mark certain fields in a `chex.dataclass` definition so that they are not included? This is a feature supported in flax https://flax.readthedocs.io/en/latest/_modules/flax/struct.html#dataclass...
Internal changes
Static type checkers like pyright, mypy, etc. will think a `chex.dataclass`-decorated dataclass has a constructor with no parameters. Example:

```
@chex.dataclass(frozen=True)
class Foo:
  a: int
  b: int
```

However, `Foo()` is...
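A workaround sometimes used for this kind of type-checker mismatch is to let the checker see the standard-library `dataclass` while the program uses `chex.dataclass` at runtime. The sketch below is an assumption, not a fix confirmed in the issue, and the `ImportError` fallback exists only so the snippet still runs where chex is not installed.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static type checkers only evaluate this branch, so they infer the
    # constructor signature from the standard-library dataclass.
    from dataclasses import dataclass
else:
    try:
        # At runtime, use chex's decorator when it is available.
        from chex import dataclass
    except ImportError:
        # Fallback for environments without chex (assumption for this sketch).
        from dataclasses import dataclass


@dataclass(frozen=True)
class Foo:
    a: int
    b: int


# Keyword arguments work with both decorators.
foo = Foo(a=1, b=2)
```

The trade-off is that the checker then validates against the standard-library semantics, so anything chex adds on top of plain dataclasses is invisible to it.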
I spent quite some time figuring out why code in a large codebase was so slow, only to find out that `jit` was disabled throughout the entire project. This was...
The [`_with_pmap`](https://github.com/deepmind/chex/blob/70350fd8fb0937034c8da7fd1dd47de7aad0747a/chex/_src/variants.py#L428) function accepts `static_argnums` as a parameter, but not `static_argnames`. This is inconsistent with other variants, such as [`with_jit`](https://github.com/deepmind/chex/blob/70350fd8fb0937034c8da7fd1dd47de7aad0747a/chex/_src/variants.py#L346) and [`with_device`](https://github.com/deepmind/chex/blob/70350fd8fb0937034c8da7fd1dd47de7aad0747a/chex/_src/variants.py#L376). Crucially, this prevents testing methods that require...
```python
from jax import jit
from jax.lax import scan
from tjax import IntegralNumeric, RealNumeric
from tjax.dataclasses import dataclass, field
import chex

def f(carry, _):
    return carry + 1.0, None

@jit...
```
Hello! I'm interested in using pydantic's recursive constructor / asdict functionality, but jax.jit-ed functions give the following error: ``` Argument '_Pydantic_OptimConfig_93971134241088(.. SOMETHING HERE...)' of type is not a valid JAX...