Gerson Kroiz

25 comments by Gerson Kroiz

Would it be simpler to have a separate file for TPU fine-tuning (for example, `adapter_tpu.py`)?

cc @Liyang90 for review

TODO: if we decide to keep these changes in `adapter.py`, we also need to make the same adjustments in `adapter_v2.py`.

When using TPUs, the current code in `adapter.py` doesn't work with more than 8 cores (4 chips). For example, say we want to run fine-tuning on `v4-64` (64...

@carmocca I noticed this PR is now closed. Do we no longer need the XLA changes?

Hi @ronghanghu, I wanted to confirm whether the autograd discrepancy in `nn.Linear()` is specific to FSDP or applies to all instances of `nn.Linear()` when using PyTorch/XLA. cc @JackCaoG

@cgarciae, here is the setup:

```
# Install newest version of JAX and jaxlib
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID --zone=$ZONE --worker=all --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
#...
```

@andsteing, the setup for TPU v5e is a bit different from v2. Could you try the `tpu-ubuntu2204-base` runtime version? It should have a newer Python version.

@andsteing

```
export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5e-16
export ZONE=us-west4-a
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id
export VALID_DURATION=1d

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id ${TPU_NAME} \
  --project...
```

@andsteing, using optax 0.1.7, we run into a new error:

Log:

```
Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers....
```