
Slow Gencast Inference on GPUs after 1st run

Open v-weh opened this issue 1 year ago • 4 comments

Many thanks for making Gencast code and weights public!

I managed to tweak the code in "gencast_demo_cloud_vm.ipynb" and got it running on an 8-GPU (H100) cluster, generating forecasts up to 15 days out at 12-hour intervals, with 8 ensemble members.

The first run took around ~35 minutes, which is expected. However, when I ran it a second time, it still took around ~30 - 35 minutes. I'm not sure if this is expected behaviour, because I thought the compilation cost was a fixed, one-time cost on the first run, and subsequent runs would take only about ~8 minutes?

Or does that only apply when using TPUs, or when generating a single forecast (e.g. just the +15 day state) rather than the entire sequence?

v-weh avatar Dec 06 '24 21:12 v-weh

Hello!

Apologies, the demo notebook implementations have an oversight here. You might notice that upon re-running the rollout cell (# @title Autoregressive rollout (loop in python)), that recompilation is triggered.

This is because xarray_jax.pmap(run_forward_jitted, dim="sample") should only be called once. To fix this, move that call out into its own cell:

%%
run_forward_pmap = xarray_jax.pmap(run_forward_jitted, dim="sample")

%%
# New cell that will no longer recompile on a second run
# ... code as before ...
for chunk in rollout.chunked_prediction_generator_multiple_runs(
    predictor_fn=run_forward_pmap,
    rngs=rngs,
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings,
    num_steps_per_chunk=1,
    num_samples=num_ensemble_members,
    pmap_devices=jax.local_devices()
    ):
    chunks.append(chunk)
predictions = xarray.combine_by_coords(chunks)
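The underlying pitfall is general to JAX: each call to jax.jit (or pmap) produces a new wrapped function with its own compilation cache, so re-wrapping in a re-run notebook cell starts from a cold cache. A minimal toy sketch of this behaviour (the double function is hypothetical, not GenCast code; the side-effect counter only fires while JAX traces):

```python
import jax
import jax.numpy as jnp

trace_count = 0

def double(x):
    global trace_count
    trace_count += 1  # side effect runs only during tracing/compilation
    return x * 2

jitted = jax.jit(double)
jitted(jnp.ones(3))
jitted(jnp.ones(3))  # same shapes/dtypes: cached, no retrace
assert trace_count == 1

# Wrapping a fresh function object (as re-running a cell does) starts with
# an empty cache, so tracing and compilation happen all over again:
jitted_again = jax.jit(lambda x: double(x))
jitted_again(jnp.ones(3))
assert trace_count == 2
```

This is why hoisting the pmap call into its own cell helps: the expensive wrapping happens once, and the rollout cell only ever calls the already-compiled function.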

Will send a fix to the repo ASAP, but thought I'd respond here first.

Thanks!

Andrew

andrewlkd avatar Dec 09 '24 16:12 andrewlkd

Many thanks for the help! Just to confirm: does the 8 minutes stated in the paper refer to the time taken to:

  1. generate a single forecast 30 steps out, i.e. (+360h), or
  2. generate the full 30-step forecast with all intermediate steps at 12-hour intervals, i.e. (+12h, +24h, +36h ... +360h)?

v-weh avatar Dec 09 '24 18:12 v-weh

The latter (on a TPUv5 and without compilation/tracing costs).

(Note that since we produce our forecasts autoregressively, the time taken to generate the 30th step - i.e. your former option - is the same as the time to produce all the intermediate steps since they are needed to feed back in as inputs!).
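The point about autoregressive generation can be illustrated with a toy sketch (step is a hypothetical stand-in for one 12 h model step, not GenCast code): reaching step 30 necessarily computes every intermediate state, because each step consumes the previous one as input.

```python
def step(state):
    """Stand-in for one 12 h forecast step."""
    return state + 1

def rollout(initial_state, num_steps):
    """Autoregressive rollout: each step feeds on the previous output."""
    states, state = [], initial_state
    for _ in range(num_steps):
        state = step(state)   # output of step N is the input to step N+1
        states.append(state)
    return states

states = rollout(0, 30)
# Producing the final (+360h) state yields all 30 states along the way,
# so options 1 and 2 above cost the same.
assert states[-1] == 30
assert len(states) == 30
```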

andrewlkd avatar Dec 09 '24 18:12 andrewlkd

Thank you! I haven't yet seen a speed-up on re-runs tbh, but I'll re-run some experiments to check whether I made a mistake and will get back to you!

v-weh avatar Dec 10 '24 10:12 v-weh