[WIP] Refactor resource management
Description
Better handling of resources, namely TPUs, so that there is a single place to initialize them and release them gracefully.
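The "one place to init and gracefully release" idea can be sketched as a context manager that initializes the JAX distributed runtime on entry and always shuts it down on exit, including when the body raises. This is a minimal sketch, not the code in this PR. The names `distributed_runtime`, `init_fn` and `shutdown_fn` are hypothetical. Only `jax.distributed.initialize()` and `jax.distributed.shutdown()` come from JAX; they are called with default arguments and injected as callables so the lifecycle can be exercised without TPUs.

```python
import contextlib
from typing import Callable, Iterator, Optional


@contextlib.contextmanager
def distributed_runtime(
    init_fn: Optional[Callable[[], None]] = None,
    shutdown_fn: Optional[Callable[[], None]] = None,
) -> Iterator[None]:
  """Hypothetical helper: initialize once, always shut down.

  init_fn/shutdown_fn default to jax.distributed.initialize/shutdown.
  They are injectable (an assumption for illustration) so tests can pass fakes.
  """
  if init_fn is None or shutdown_fn is None:
    import jax  # Imported lazily so callers passing fakes do not need JAX.

    init_fn = init_fn or jax.distributed.initialize
    shutdown_fn = shutdown_fn or jax.distributed.shutdown

  # If initialization itself fails, there is nothing to shut down.
  init_fn()
  try:
    yield
  finally:
    # Runs on normal exit and on exceptions, so the runtime is always released.
    shutdown_fn()
```

A training entry point would then wrap its whole body in `with distributed_runtime(): ...` instead of calling `jax.distributed.initialize()` from individual helpers.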
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [X] I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
- [X] I have necessary comments in my code, particularly in hard-to-understand areas.
- [X] I have run end-to-end tests and provided workload links above if applicable.
- [X] I have made or will make corresponding changes to the doc if needed.
Can you clarify the motivation for this a bit more? Is there a particular multi-tenancy setup you are running that needs this resource management? My initial thought is that this resource management should be done by a higher layer, i.e., whatever is scheduling the workloads that need to share resources, such as GKE.
Also note that this logic seems highly coupled with multi-controller JAX and would likely not be needed with Pathways (Pathways does not require calling jax.distributed.initialize()).
The goal is to have a single place for resource management. Currently, jax.distributed.initialize() is called but jax.distributed.shutdown() never is. In addition, jax.distributed.initialize() is called inside several functions in MaxText. In a Colab, this can leave the TPU unreleased even after the job is done.
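For the case where `jax.distributed.initialize()` is currently called from several MaxText functions, one option is a single owner object whose initialization is idempotent and whose shutdown can be triggered explicitly, for example from the last cell of a Colab. This is a hedged sketch under assumptions: the `DistributedRuntime` class and its `ensure_initialized()`/`shutdown()` methods are hypothetical and not part of MaxText or JAX; only `jax.distributed.initialize()` and `jax.distributed.shutdown()` are real JAX calls, used with default arguments.

```python
import threading
from typing import Callable, Optional


class DistributedRuntime:
  """Hypothetical single owner of the jax.distributed lifecycle.

  Helpers call ensure_initialized() instead of jax.distributed.initialize(),
  so repeated calls are no-ops; shutdown() releases the runtime at most once.
  """

  def __init__(
      self,
      init_fn: Optional[Callable[[], None]] = None,
      shutdown_fn: Optional[Callable[[], None]] = None,
  ) -> None:
    # Injectable callables (an assumption) so the logic is testable without TPUs.
    self._init_fn = init_fn
    self._shutdown_fn = shutdown_fn
    self._initialized = False
    self._lock = threading.Lock()

  def _resolve(self) -> None:
    # Fall back to the real JAX functions only when no fakes were supplied.
    if self._init_fn is None or self._shutdown_fn is None:
      import jax

      self._init_fn = self._init_fn or jax.distributed.initialize
      self._shutdown_fn = self._shutdown_fn or jax.distributed.shutdown

  def ensure_initialized(self) -> None:
    with self._lock:
      if not self._initialized:
        self._resolve()
        self._init_fn()
        self._initialized = True

  def shutdown(self) -> None:
    with self._lock:
      if self._initialized:
        self._shutdown_fn()
        self._initialized = False

  @property
  def initialized(self) -> bool:
    return self._initialized
```

The lock only guards against concurrent calls within one process; coordination across hosts is still handled by JAX itself.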
This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions.
This PR was closed because it has been inactive for a while. Please reopen it if you are still working on it.