Patrick Nguyen
Fix typo
Fix `kwags` -> `kwargs`. No behavior change.
So, it would be nice to have runtime asserts, but we can't have them in a pure function. On x86, when the processor encounters an error, it raises the...
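One option JAX offers for runtime checks inside a pure, traced function is `jax.experimental.checkify`. It is named here as an alternative, not as what this thread settled on. `checkify.check` does not raise inside the traced code. Instead it records the failure as a functional error value that is returned alongside the result, and the caller decides whether to raise. This is a minimal sketch, and `safe_log` is a hypothetical example function:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
  # Hypothetical example: record a failed check instead of raising mid-trace.
  checkify.check(jnp.all(x > 0), "safe_log: expected positive inputs")
  return jnp.log(x)


# checkify turns the function into one returning (error, result); it composes with jit.
checked_log = jax.jit(checkify.checkify(safe_log))

ok_err, ok_out = checked_log(jnp.array([1.0, 2.0]))
ok_err.throw()  # No-op: every input was positive.

bad_err, _ = checked_log(jnp.array([1.0, -2.0]))
message = bad_err.get()  # The recorded failure message (None when no check failed).
```

The assert becomes data flowing out of the function, so purity is preserved and raising happens outside the traced region.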
I tried

```
import jax.numpy as jnp

def AssertShape(x: jnp.array, shape) -> None:
  if not jnp.array_equal(x.shape, shape):
    raise ValueError(f'Shape mismatch: found {x.shape}, expected: {shape}')
```

and got

```
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced...
```
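Under `jax.jit`, `jnp.array_equal` returns a traced value, so `if not ...` forces a boolean conversion of a tracer, which is what raises the error. Array shapes, however, are static: `x.shape` is a plain Python tuple of ints even while tracing. A minimal sketch of a trace-time check that avoids tracers entirely, assuming the shape check is all that is needed (`assert_shape` and `double` are hypothetical names for illustration):

```python
import jax
import jax.numpy as jnp


def assert_shape(x, shape) -> None:
  # x.shape is a static Python tuple under jit, so this comparison runs at
  # trace time with ordinary Python booleans and never touches a tracer.
  if tuple(x.shape) != tuple(shape):
    raise ValueError(f'Shape mismatch: found {x.shape}, expected: {tuple(shape)}')


@jax.jit
def double(x):
  assert_shape(x, (2, 3))  # Checked once per trace, not on every call.
  return x * 2


out = double(jnp.ones((2, 3)))  # Traces and runs.
```

Calling `double(jnp.ones((3, 2)))` triggers a retrace, and the `ValueError` is raised during that trace, before any device computation. This only covers static properties such as shapes and dtypes; checks on array values need something like `checkify`.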
Not sure of the root cause, but `if summary_utils.write_summary_every_n_steps(...` in `lingvo/jax/train.py` fails, apparently because the metric is not replicated. Setting `unreplicate_metrics=False` works for me.
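For context, here is a generic sketch of what "replicated" metrics look like under `jax.pmap`. Every pmapped output gets a leading per-device axis, and "unreplicating" usually means keeping a single device's copy. This is not lingvo's implementation: `unreplicate` is a hypothetical helper, and the way `lingvo/jax/train.py` handles `unreplicate_metrics` is not reproduced here.

```python
import jax
import jax.numpy as jnp


def unreplicate(tree):
  # Hypothetical helper: keep the copy held by the first device for every leaf.
  return jax.tree_util.tree_map(lambda x: x[0], tree)


n = jax.local_device_count()
# pmap maps over the leading axis, so the output dict gains a device axis of size n.
metrics = jax.pmap(lambda x: {'loss': jnp.mean(x)})(jnp.ones((n, 4)))
single = unreplicate(metrics)  # Leaves drop the device axis and become scalars.
```

Code that expects one layout (scalar metrics) but receives the other (a leading device axis), or the reverse, is one plausible way a summary step can fail on metric shape.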
This should be translated into `# Implicit dependency on...` in `lingvo/jax/BUILD` since users already use the pip package.