Set TF_FORCE_GPU_ALLOW_GROWTH=true by default

samos123 opened this pull request.

This is needed to run Fuji v2 70B on GPU without running out of GPU memory (OOM).

@kelvin-zou can likely confirm whether this should be the default or not.

samos123, Sep 24 '24 21:09

Hmm, but we also set a lot of TPU environment variables in launch.py without any if statements. I don't think there is a better place, since the variable has to be set before JAX is started?
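A minimal sketch of why the ordering matters, assuming the launcher starts the training job as a child process. A process receives its environment when it starts, so a value placed in the parent's environment before the child imports TensorFlow or JAX is what the child sees. As we understand it, TensorFlow reads `TF_FORCE_GPU_ALLOW_GROWTH` when it sets up its GPU devices, so changing it after that point has no effect. `launch_with_gpu_env` is a hypothetical helper for illustration, not part of axlearn's launch.py.

```python
import os
import subprocess
import sys


def launch_with_gpu_env(cmd: list[str]) -> subprocess.CompletedProcess:
    """Hypothetical helper: runs `cmd` with TF_FORCE_GPU_ALLOW_GROWTH in its environment.

    The child inherits `env` at startup, i.e. before it can import
    TensorFlow or JAX, which is the ordering constraint discussed above.
    """
    env = dict(os.environ)
    # Keep any value the user already exported; otherwise default to "true".
    env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
    return subprocess.run(cmd, env=env, capture_output=True, text=True, check=True)


if __name__ == "__main__":
    # The child process simply reports the value it inherited.
    child = [sys.executable, "-c", "import os; print(os.environ['TF_FORCE_GPU_ALLOW_GROWTH'])"]
    result = launch_with_gpu_env(child)
    print(result.stdout.strip())
```

Setting the variable from inside the child after TensorFlow has already initialized its GPU devices would be too late, which is why a launcher is a natural place for it.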

Would you prefer this?

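```python
import os

# `instance_type` is the instance type string already available in launch.py.
if instance_type.startswith("gpu"):
    # Prevent GPU OOMs caused by TensorFlow pre-allocating all GPU memory.
    # Reference: https://stackoverflow.com/a/54927279
    os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
```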
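The gated version can be wrapped in a small function to check its two properties: it does nothing for non-GPU instance types, and `setdefault` never overwrites a value the user exported. This is a sketch only; `maybe_set_gpu_env` is a hypothetical name, and the instance-type strings in the demo are made-up examples that just follow the `"gpu"` prefix convention from the snippet.

```python
import os
from typing import MutableMapping, Optional


def maybe_set_gpu_env(instance_type: str, environ: Optional[MutableMapping[str, str]] = None) -> None:
    """Hypothetical helper: applies GPU-only env defaults; a no-op for other instance types."""
    if environ is None:
        environ = os.environ
    if instance_type.startswith("gpu"):
        # setdefault keeps any value the user exported explicitly.
        environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")


if __name__ == "__main__":
    env: dict = {}
    maybe_set_gpu_env("tpu-example", env)  # made-up TPU instance type
    print(env)  # {}
    maybe_set_gpu_env("gpu-example", env)  # made-up GPU instance type
    print(env)  # {'TF_FORCE_GPU_ALLOW_GROWTH': 'true'}
```

Passing the mapping explicitly keeps the function testable without touching the real process environment.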

samos123, Sep 24 '24 21:09

Hi @samos123, is this PR still relevant?

changlan, Jul 26 '25 00:07

This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the stale label or leave a comment.

github-actions[bot], Oct 17 '25 02:10

This pull request was closed because it has been inactive for more than 7 days since being marked as stale. Please feel free to reopen it if you would like to continue.

github-actions[bot], Oct 27 '25 02:10