Gerson Kroiz
This PR adds changes specific to finetuning on TPUs. Used a TPU v4-8 with:

```
log_interval = 1
devices = 4
batch_size = 64 / devices
micro_batch_size = 4
gradient_accumulation_steps = ...
```
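As an illustration only, a minimal sketch of how hyperparameters like these are typically wired together in a lit-gpt-style finetuning script. The variable names mirror the snippet above, but the derivation of the gradient accumulation value is an assumption for illustration, not the exact value from this PR:

```python
# Hypothetical sketch of the TPU v4-8 finetuning hyperparameters referenced above.
# The accumulation derivation below is an assumption, not taken from the PR.
log_interval = 1
devices = 4
batch_size = 64 // devices        # per-device batch size (integer division for clarity)
micro_batch_size = 4              # samples processed per forward/backward pass

# lit-gpt-style finetune scripts commonly derive accumulation from the two batch sizes:
gradient_accumulation_steps = batch_size // micro_batch_size
assert batch_size % micro_batch_size == 0, "batch_size must be divisible by micro_batch_size"
```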
PR to add FLOPS numbers for benchmarking when using TPUs. Follow-up to https://github.com/Lightning-AI/lit-gpt/pull/147#discussion_r1230009025
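For context, a hedged sketch of one common way such benchmark numbers are reported: estimate the model FLOPs per step with the 6·N·T approximation and divide by the measured step time and the devices' peak throughput. The peak-FLOPS constant and the function names here are illustrative assumptions, not the PR's implementation:

```python
import time

# Illustrative peak bf16 throughput of a single TPU v4 chip (assumed constant).
TPU_V4_PEAK_FLOPS = 275e12


def estimate_step_flops(num_params: int, tokens_per_step: int) -> float:
    """Rough transformer training FLOPs per step using the 6 * N * T approximation."""
    return 6.0 * num_params * tokens_per_step


def measure_utilization(step_fn, num_params: int, tokens_per_step: int, num_devices: int) -> float:
    """Time one training step and report an MFU-style utilization estimate."""
    start = time.perf_counter()
    step_fn()  # a single forward/backward/optimizer step supplied by the caller
    elapsed = time.perf_counter() - start
    achieved_flops = estimate_step_flops(num_params, tokens_per_step) / elapsed
    return achieved_flops / (TPU_V4_PEAK_FLOPS * num_devices)
```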
Implements PyTorch XLA FSDP for TPUs. This PR is based on https://github.com/Lightning-AI/lightning/pull/17421, but focuses only on Trainer-related changes. Use the `XLAFSDPStrategy` to enable FSDP on TPUs. Fixes https://github.com/Lightning-AI/lightning/issues/13209
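A minimal usage sketch, assuming the strategy is exposed under `lightning.pytorch.strategies` as in Lightning 2.x; the toy `LitModel`, the dataloader, and the `devices`/`precision` settings are placeholders, not prescribed by this PR:

```python
import torch
import lightning as L
from lightning.pytorch.strategies import XLAFSDPStrategy  # import path assumed from Lightning 2.x


class LitModel(L.LightningModule):
    """Tiny placeholder module standing in for the user's real model."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    data = torch.utils.data.TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    loader = torch.utils.data.DataLoader(data, batch_size=8)

    # Shard parameters across the TPU cores with XLA FSDP.
    trainer = L.Trainer(
        accelerator="tpu",
        devices=8,                    # e.g. a v4-8 slice; adjust to the available cores
        strategy=XLAFSDPStrategy(),
        precision="bf16-true",
        max_epochs=1,
    )
    trainer.fit(LitModel(), loader)
```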
With newer versions of Flax, there seems to be an outdated API call in the imagenet test. First, when testing with the most recent stable release of JAX, 0.4.16, and...