Ivy Zheng

Results 29 comments of Ivy Zheng

> Looks good! Should we add a simple test?

Test added!

PR https://github.com/google/flax/pull/2697 is moving `tf_errors` inside the file system shim `io.py` and dropping TF from the required import list of `flax.training.checkpoints`. Now `metrics/tensorboard.py` should be the only place where...

`jax.experimental` has an [implementation](https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py) of FlashAttention, written as Pallas kernels and therefore usable on both GPU and TPU. We can probably upstream this to Flax attention if `jax.experimental` doesn't scare...

Looks like the two attention kernels are for different platforms - one for [TPU](https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py) and another for [GPU](https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/attention.py). The implementations are slightly different for performance reasons, and they are also...

Dropout is not currently available in the Pallas kernels, as Pallas does not yet support PRNG keys. Causal masking can be turned on with `causal=True`, and the TPU version has attention...
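To make the causal masking option concrete, here is a minimal sketch in plain `jax.numpy` of what a causal mask does to attention. It is not the Pallas kernel and makes no claim about that kernel's signature; `reference_causal_attention` is a hypothetical name, and the single-head `[seq_len, head_dim]` layout is an assumption chosen for brevity.

```python
import jax
import jax.numpy as jnp


def reference_causal_attention(q, k, v):
    """Unfused single-head attention with a causal mask.

    Hypothetical reference helper; inputs have shape [seq_len, head_dim].
    """
    seq_len, head_dim = q.shape
    logits = (q @ k.T) / jnp.sqrt(head_dim)
    # Lower-triangular mask: position i may attend only to positions j <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    logits = jnp.where(mask, logits, -jnp.inf)
    # Every row keeps its diagonal entry, so the softmax is always well defined.
    weights = jax.nn.softmax(logits, axis=-1)
    return weights @ v


kq, kk, kv = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(kq, (4, 8))
k = jax.random.normal(kk, (4, 8))
v = jax.random.normal(kv, (4, 8))
out = reference_causal_attention(q, k, v)
```

Because the first position can only attend to itself, the first output row equals `v[0]`, which is a quick sanity check for any causal implementation.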

`jax.random.PRNGKey` is the older, legacy version of `jax.random.key`, so using the latter is preferred. You are using a very old JAX version that still uses `PRNGKey`. And it looks like...
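A short sketch of the difference between the two key constructors, assuming a JAX version recent enough to ship `jax.random.key` and the default PRNG implementation:

```python
import jax

# New-style typed key (preferred): a scalar array with a special key dtype.
key = jax.random.key(0)

# Legacy key: a raw uint32 array, kept for backward compatibility.
legacy_key = jax.random.PRNGKey(0)

# Both kinds of key are accepted by the jax.random samplers.
x_new = jax.random.normal(key, (3,))
x_old = jax.random.normal(legacy_key, (3,))

# key_data exposes the raw uint32 buffer behind a typed key.
raw = jax.random.key_data(key)
```

With the default implementation, the same seed gives the same samples either way, so switching from `PRNGKey` to `key` should not change results by itself.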

Ah, the JAX documentation is a bit outdated... One of the JAX owners actually wrote the PR that converted all Flax examples to use `jax.random.key` (https://github.com/google/flax/pull/3337). I might open a PR...

Sorry for the late reply - I am actually debating whether to use `with mesh` or the `mesh_sharding()` util in the guide. It's quite common for larger model libraries to...
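For context, a `mesh_sharding()`-style helper can be sketched roughly as below. This is an assumption about what such a util looks like (a thin wrapper that binds a `PartitionSpec` to a fixed mesh), not a copy of the guide's code; the axis name `"data"` and the array shape are made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1D mesh over whatever devices are available (a single CPU device works).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))


def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    # Hypothetical helper: bind a PartitionSpec to the mesh defined above.
    return NamedSharding(mesh, pspec)


# The leading dimension is a multiple of the device count so it splits evenly.
x = jnp.ones((2 * devices.size, 4))
x_sharded = jax.device_put(x, mesh_sharding(PartitionSpec("data", None)))
```

The explicit helper keeps the mesh visible at every call site, whereas `with mesh` makes it ambient context; which reads better in a guide is the judgment call being discussed.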

The module in Flax is called `flax.linen` - it's just that people often write `import flax.linen as nn` for shorthand.

On the mistaken `0.7.1` release, no PyPI package was built and uploaded because the version check did not pass. And since it didn't really impact PyPI, it shouldn't be a...