Set TF_FORCE_GPU_ALLOW_GROWTH=true by default
This is needed to run Fuji v2 70B on GPU without running out of GPU memory (OOM).
@kelvin-zou can likely confirm whether this should be the default or not.
Hmm, but we also set a lot of TPU environment variables in launch.py without any if statements. I don't think there is a better place, since it needs to happen before JAX is started.
Would you prefer this?
if instance_type.startswith("gpu"):
    # Prevent GPU OOM issues due to TF taking up all the GPU memory.
    # Reference: https://stackoverflow.com/a/54927279
    os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
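A minimal, self-contained sketch of the conditional version, so the behavior can be checked without a GPU. The helper name `configure_accelerator_env` and the instance type strings are hypothetical. The only assumption carried over from the snippet is that GPU instance types start with "gpu". Because it uses `setdefault`, a value the user already exported (e.g. `TF_FORCE_GPU_ALLOW_GROWTH=false`) is left alone. As noted earlier in the thread, it would need to run before JAX is started.

```python
import os
from typing import MutableMapping, Optional


def configure_accelerator_env(
    instance_type: str, environ: Optional[MutableMapping[str, str]] = None
) -> MutableMapping[str, str]:
    """Hypothetical helper: apply GPU-only environment defaults.

    Defaults to mutating os.environ; pass a dict to test without side effects.
    Call this before jax is imported/started.
    """
    env = os.environ if environ is None else environ
    if instance_type.startswith("gpu"):
        # Let TF grow its GPU memory allocation on demand instead of grabbing
        # all of it up front. setdefault keeps any value the user already set.
        env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
    return env


if __name__ == "__main__":
    demo: dict = {}
    configure_accelerator_env("gpu-example", demo)  # hypothetical instance type
    print(demo)
```

Passing a plain dict instead of `os.environ` keeps the check free of side effects. Non-GPU instance types leave the mapping untouched, which matches the `if` in the snippet.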
Hi @samos123, is this PR still relevant?
This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the stale label or leave a comment.
This pull request was closed because it has been inactive for more than 7 days since being marked as stale. Please feel free to reopen it if you would like to continue.