Johannes E. M. Mosig
Sorry, @guilherme-mendes, I didn't see the message. I typically only see this when you click the re-review arrows and request another review. I'll have a look today :)
One could use [AQT](https://github.com/google/aqt) for this, if penzai exposed the `dot_general` function somehow. But, e.g., `Linear` is implemented in terms of `jnp.tensordot`, which uses `lax.dot_general` [under the hood](https://github.com/jax-ml/jax/blob/f3224caf462eb4f5618d16d37f122f15c919b4ae/jax/_src/numpy/tensor_contractions.py#L548), and...
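To illustrate the point (a minimal sketch, not penzai's actual code): the contraction that `Linear` does with `jnp.tensordot` can be spelled directly as `lax.dot_general`, which is exactly the call site where an injectable quantized implementation would have to be plugged in.

```py
import jax
import jax.numpy as jnp

# Hypothetical shapes for a Linear-style contraction.
x = jnp.ones((4, 8))   # (batch, features_in)
w = jnp.ones((8, 16))  # (features_in, features_out)

y_tensordot = jnp.tensordot(x, w, axes=((1,), (0,)))

# Same contraction spelled as lax.dot_general: contract axis 1 of x
# with axis 0 of w, no batch dimensions.
dimension_numbers = (((1,), (0,)), ((), ()))
y_dot_general = jax.lax.dot_general(x, w, dimension_numbers)

assert jnp.allclose(y_tensordot, y_dot_general)

# If the layer accepted dot_general as a parameter, a quantized variant
# (e.g. one built with AQT) could be passed in here instead of lax.dot_general.
```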
I had the same thought! Your naming is better, though :) I'll try to implement the AQT layers when I find the time. But I think doing so will involve...
It looks like using AQT directly is a bit trickier than I thought, as AQT objects carry around state for calibration and the AQT code generally seems to be...
@demoncoder-crypto This looks like the successor to AQT and might be interesting to look into: https://github.com/google/qwix
It seems to work when I ignore both, but I don't know yet what happens when `use_cache = True`.
I managed to load the frozen layer-stacked model into GPU by stacking each parameter group on CPU and pushing it to GPU one at a time. Unfortunately, running a forward...
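For reference, a rough sketch of that workaround (the names here are illustrative, not penzai's API): build each parameter group on the CPU and move it to the GPU with `jax.device_put`, one group at a time, instead of transferring the whole tree at once.

```py
import jax

gpu = jax.devices("gpu")[0]

def push_group_to_gpu(group):
    # group is a pytree of arrays currently living on the CPU;
    # device_put transfers each leaf to the GPU.
    return jax.tree_util.tree_map(lambda x: jax.device_put(x, gpu), group)

# param_groups is assumed to be a dict of parameter groups built on CPU:
# param_groups = {name: push_group_to_gpu(g) for name, g in param_groups.items()}
```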
Running

```py
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".98"
```

before any of this does not help.
Just passing the example input through the first layer (embedding lookup, outside the layer stack) does not result in an OOM, so it really has something to do with the...
Thanks for the reply!

> I wonder if the initialization issue and the runtime issue have the same cause or different causes.

I am unsure about this, too.

> Do...