Davis Yoshida
This is just an implementation of VGG-16 in Haiku, plus pretrained ImageNet weights. I'm not sure if it's substantial enough to make the list, but I figured I'd make a...
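For readers unfamiliar with the architecture, here is a minimal pure-Python sketch of the standard VGG-16 layout ("configuration D" in the original VGG paper: 13 3x3 convolutions in five max-pooled blocks, followed by three dense layers). It counts parameters for a 224x224 RGB input and 1,000 ImageNet classes. `VGG16_CONFIG` and `vgg16_param_count` are hypothetical names, and the sketch describes the standard model, not necessarily the exact module layout of the Haiku port or its checkpoint.

```python
# Standard VGG-16 ("configuration D"): ints are 3x3 conv output channels, "M" is a 2x2 max-pool
VGG16_CONFIG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
                512, 512, 512, "M", 512, 512, 512, "M"]


def vgg16_param_count(in_channels=3, image_size=224, num_classes=1000):
    """Count weights + biases of the VGG-16 conv stack and its three dense layers."""
    total = 0
    channels, size = in_channels, image_size
    for layer in VGG16_CONFIG:
        if layer == "M":
            size //= 2  # each 2x2, stride-2 max-pool halves the spatial size
        else:
            total += 3 * 3 * channels * layer + layer  # 3x3 kernel weights + bias
            channels = layer
    flat = channels * size * size  # 512 * 7 * 7 = 25088 for a 224x224 input
    for units in (4096, 4096, num_classes):
        total += flat * units + units  # dense weights + bias
        flat = units
    return total


num_convs = sum(1 for layer in VGG16_CONFIG if layer != "M")
print(num_convs + 3)        # 16 weight layers
print(vgg16_param_count())  # 138357544
```

The dense layers dominate: the first one alone (25088 x 4096 weights plus biases) accounts for about 102.8M of the roughly 138.4M parameters.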
I thought the following would work:

```python
def _fn(x):
    return hk.GroupNorm(groups=32, data_format='channels_first')(x)

x = jnp.zeros((128, 100, 100))  # C x H x W
fn = hk.transform(_fn)
params = fn.init(jax.random.PRNGKey(0),...
```
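The snippet is cut off, so the actual error isn't shown. One plausible cause, offered as an assumption to check against the Haiku `GroupNorm` docs: with `data_format='channels_first'` the module reads channels from axis 1 and expects a leading batch axis (`[N, C, ...]`), so a bare C x H x W array would be read as N=128, C=100, and 100 channels don't split evenly into 32 groups. The pure-JAX sketch below uses a hypothetical `group_norm_nchw` helper. It is not Haiku's implementation and has no learned scale/offset; it shows how the group statistics are computed and how adding a batch axis with `x[None]` fixes the layout.

```python
import jax
import jax.numpy as jnp


def group_norm_nchw(x, groups, eps=1e-5):
    """Group normalization for a channels-first [N, C, ...] array (no scale/offset)."""
    n, c = x.shape[0], x.shape[1]
    if c % groups != 0:
        raise ValueError(f"channels ({c}) must be divisible by groups ({groups})")
    spatial = x.shape[2:]
    # Split the channel axis into groups: [N, G, C // G, *spatial]
    xg = x.reshape((n, groups, c // groups) + spatial)
    # Statistics are per sample and per group, over channels-in-group and spatial axes
    axes = tuple(range(2, xg.ndim))
    mean = xg.mean(axis=axes, keepdims=True)
    var = xg.var(axis=axes, keepdims=True)
    xg = (xg - mean) / jnp.sqrt(var + eps)
    return xg.reshape(x.shape)


# An unbatched C x H x W input: add a leading batch axis so channels sit on axis 1
x = jax.random.normal(jax.random.PRNGKey(0), (128, 100, 100))
y = group_norm_nchw(x[None], groups=32)
print(y.shape)  # (1, 128, 100, 100)
```

Passing `x` without the batch axis makes this helper treat 128 as the batch size and 100 as the channel count, so it raises `ValueError`.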
### Description

Running the following with JAX 0.4.25 causes an `AttributeError`:

```python
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def do_dot(a, b, out_ref):
    out_ref[0, 0, :, :]...
```
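For comparison, here is a minimal Pallas matrix-product kernel following the usual pattern: kernel arguments are refs, read with `ref[...]` and written with `o_ref[...] = ...`. This is a sketch, not the reporter's code, and it does not reproduce or diagnose the `AttributeError`. `matmul_kernel` is a hypothetical name, and `interpret=True` is used so the kernel runs on CPU without a GPU/TPU backend.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def matmul_kernel(a_ref, b_ref, o_ref):
    # With no grid, the kernel runs once and each ref holds the full array
    o_ref[...] = jnp.dot(a_ref[...], b_ref[...])


a = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
b = jax.random.normal(jax.random.PRNGKey(1), (16, 8))

out = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 8), a.dtype),
    interpret=True,  # run through the Pallas interpreter (CPU-friendly)
)(a, b)
print(out.shape)  # (8, 8)
```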
We're planning to do some work with some of the pretrained models. Are the training hyperparameters documented anywhere?