Sebastian Bodenstein

10 issues by Sebastian Bodenstein

I can't find any license information for the Atari ROMs shipped in this repository. Is there any information that could clarify this? Thanks!

@zhreshold: do you have any plans to upstream these custom layers to MXNet?

Using an A100 Colab:
```
import jax
import jax.numpy as jnp

print(jnp.array(1).device().device_kind)

@jax.jit
def f(x, y):
    return jnp.einsum('bqc,bkc->bqk', x, y)

x_bfloat = jnp.ones((384 * 4, 384, 16), dtype=jnp.bfloat16)
x_float =...
```

bug
P1 (soon)
NVIDIA GPU

First, thanks for this great repo, it's helped a lot of people over the years! I wanted to mention that 11.3 has an official package for interfacing with MongoDB: https://reference.wolfram.com/language/MongoLink/guide/MongoLinkOperations.html...

### Description
When defining custom kernels, there are three distinct kernels for a `jax.custom_vjp`: `f`, `f_fwd`, `f_bwd`. When inside a `jax.vjp` and `jax.remat`, all three kernels should be called: first `f`,...

bug
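For readers unfamiliar with the three-kernel structure, a minimal sketch of a `jax.custom_vjp` with `f`, `f_fwd` and `f_bwd`, differentiated with `jax.vjp` both directly and under `jax.checkpoint` (which `jax.remat` aliases). The `sin`/`cos` bodies are illustrative assumptions; this does not reproduce the kernels from the issue and makes no claim about the order in which JAX invokes them under remat.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    # Primal kernel: used when f is called without differentiation.
    return jnp.sin(x)


def f_fwd(x):
    # Forward kernel: returns the primal output and residuals saved for f_bwd.
    return jnp.sin(x), jnp.cos(x)


def f_bwd(cos_x, g):
    # Backward kernel: maps the output cotangent g to a tuple of input cotangents.
    return (g * cos_x,)


f.defvjp(f_fwd, f_bwd)

x = jnp.float32(0.5)

# Plain reverse mode.
y, vjp_fn = jax.vjp(f, x)
(x_bar,) = vjp_fn(jnp.ones_like(y))

# Same computation with f wrapped in jax.checkpoint (jax.remat is an alias).
y_remat, vjp_fn_remat = jax.vjp(jax.checkpoint(f), x)
(x_bar_remat,) = vjp_fn_remat(jnp.ones_like(y_remat))
```

In both cases the gradient comes from `f_bwd`, so `x_bar` and `x_bar_remat` should both equal `cos(0.5)`.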

### Description
`jax.nn.dot_product_attention` does the first dot product with `preferred_element_type=jnp.float32` (see [here](https://github.com/jax-ml/jax/blob/7f655972c47658768b6ecce752fa29c3a64a824a/jax/_src/nn/functions.py#L846)). For BF16 inputs, this prevents an unnecessary downcast to BF16 (which can improve numerical stability, and has no extra...

bug
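The effect of `preferred_element_type` can be seen on a standalone `jnp.einsum`, which accepts the same argument. A minimal sketch with hypothetical bfloat16 query/key tensors; it applies the argument to a plain `bqc,bkc->bqk` contraction and is not the internal code of `jax.nn.dot_product_attention`.

```python
import jax.numpy as jnp

# Hypothetical shapes (batch, queries/keys, channels); values are arbitrary.
q = jnp.ones((2, 4, 8), dtype=jnp.bfloat16)
k = jnp.ones((2, 4, 8), dtype=jnp.bfloat16)

# Default: the output dtype follows the bfloat16 inputs.
logits_bf16 = jnp.einsum('bqc,bkc->bqk', q, k)

# Request a float32 result, so the logits are not rounded to bfloat16.
logits_f32 = jnp.einsum('bqc,bkc->bqk', q, k, preferred_element_type=jnp.float32)
```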

Why does optimistix use a custom jvp rather than the `jax.lax.custom_root` primitive? https://docs.jax.dev/en/latest/_autosummary/jax.lax.custom_root.html

question
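For comparison, a minimal sketch of `jax.lax.custom_root` on a scalar problem: computing `sqrt(a)` as the root of `x**2 - a` with a hand-written Newton solver. The solver, iteration count and initial guess are illustrative assumptions, and this is not how Optimistix is implemented. `custom_root` differentiates the solution implicitly through `f`, so gradients do not flow through the Newton iterations themselves.

```python
import jax
import jax.numpy as jnp


def newton_solve(f, x0):
    # Hypothetical solver: a fixed number of Newton steps for scalar f.
    def step(_, x):
        return x - f(x) / jax.grad(f)(x)
    return jax.lax.fori_loop(0, 20, step, x0)


def scalar_tangent_solve(g, y):
    # g is f linearized at the root (x -> J * x); solve J * x = y for scalar J.
    return y / g(jnp.ones_like(y))


def sqrt_via_root(a):
    a = jnp.asarray(a, dtype=jnp.float32)
    f = lambda x: x**2 - a  # root at x = sqrt(a) for x > 0
    return jax.lax.custom_root(f, jnp.ones_like(a), newton_solve, scalar_tangent_solve)


value = sqrt_via_root(2.0)
slope = jax.grad(sqrt_via_root)(2.0)  # implicit derivative 1 / (2 * sqrt(a))
```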

```
import equinox as eqx

class Error(eqx.Enumeration):
    error = 'error'
    success = 'success'

def f(x: Error) -> Error:
    return x

f(Error.error)
```
gives under mypy:
```
11: error: Argument 1...
```

Suppose you have a dataclass:
```
@jaxtyped(typechecker=typechecker)
@dataclass
class MyDataclass:
    x: Float[Array, "n"]

    @property
    def plus1(self) -> Float[Array, "n"]:
        return self.x + 1.0
```
There is no way AFAIK how...

question

This
```
import jax
from jax import export
from jax import numpy as jnp
import jaxtyping as jt
import typeguard

@jt.jaxtyped(typechecker=typeguard.typechecked)
def f(
    x: jt.Float[jt.Array, "*#B"],
) -> jt.Float[jt.Array, "*#B"]:...
```