Marcos Treviso
Having a way to turn on accumulate-gradients/update-freq would be amazing for reproducibility on GPUs. What is the best approach for doing this in JAX?
Hi Mostafa! Thank you for the quick response. I was able to adapt your code for text classification, and the gradient accumulation seems to be working fine. Since `jax.lax.fori_loop`...
Hi! I got the following results on the test set by using a single GPU (24GB) and setting `accum_steps=batch_size`. All hyperparameters were kept intact, and the only thing that changed...
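For readers following this thread, here is a minimal sketch of gradient accumulation in JAX built on `jax.lax.fori_loop`. It is not the code discussed above. `loss_fn`, `accumulated_grads` and the toy linear model are hypothetical. The sketch assumes a mean-reduced loss and a batch size divisible by `accum_steps`.

The batch is split into `accum_steps` equal micro-batches. Their gradients are summed inside the loop, and the sum is then divided by `accum_steps`. With `accum_steps=batch_size`, each micro-batch holds a single example.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical linear-regression loss (mean over examples), for illustration only.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulated_grads(params, x, y, accum_steps):
    """Average gradients over `accum_steps` equal micro-batches of (x, y).

    Assumes x.shape[0] is divisible by accum_steps (a Python int).
    """
    micro = x.shape[0] // accum_steps  # static micro-batch size
    grad_fn = jax.grad(loss_fn)

    def body(i, acc):
        # Slice the i-th micro-batch; dynamic_slice_in_dim accepts a traced start index.
        xb = jax.lax.dynamic_slice_in_dim(x, i * micro, micro)
        yb = jax.lax.dynamic_slice_in_dim(y, i * micro, micro)
        g = grad_fn(params, xb, yb)
        # Add this micro-batch's gradients to the running sum (same pytree structure).
        return jax.tree_util.tree_map(jnp.add, acc, g)

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total = jax.lax.fori_loop(0, accum_steps, body, zeros)
    # Equal-sized micro-batches + mean loss: dividing by accum_steps
    # recovers the full-batch gradient.
    return jax.tree_util.tree_map(lambda g: g / accum_steps, total)
```

Only one micro-batch is differentiated per loop iteration, which is roughly where the memory saving comes from. The cost is `accum_steps` sequential forward/backward passes. If you wrap the function in `jax.jit`, mark `accum_steps` as static (e.g. `static_argnames="accum_steps"`), because it determines the slice shapes.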
A simple solution is to install `planar` from a previous commit:

```
!pip install -qqq git+https://github.com/chalk-diagrams/planar@1e06d5894af31984323532092848d87a98278235
!pip install -qqq git+https://github.com/danoneata/chalk@srush-patch-1
!wget -q https://github.com/srush/GPU-Puzzles/raw/main/robot.png
!wget -q https://github.com/srush/GPU-Puzzles/raw/main/lib.py
```