
Trade operator recompute for memory

pengwa opened this issue 3 years ago

Description: Operator-level recompute

This PR adds an optional capability that trades additional recomputation for better memory efficiency. Specifically, a pre-defined operator list is used to iterate the graph and find stashed activations whose lifetime spans the forward and backward passes.

In this first phase, it still needs manual configuration, or even trial and error, to see which kinds of operators should be enabled for recompute. Ideally, we could analyze the allocation planner for the memory peak and automatically pick candidates that live within that peak (this can be a TODO if more models benefit from the feature).
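The candidate search described above can be sketched as follows. This is an illustrative Python sketch, not ORT code: the operator allow-list, the node representation, and the function name are all hypothetical, and the real pass works on the ONNX `Graph` in C++ (`memory_alleviation.cc`).

```python
# Sketch of the idea: walk the graph with a pre-defined allow-list of
# cheap-to-recompute op types, and mark stashed activations (forward outputs
# kept alive for the backward pass) whose producer is on that list.
RECOMPUTABLE_OPS = {"Gelu", "Dropout", "Tile"}  # hypothetical allow-list

def find_recompute_candidates(nodes, stashed_activations):
    """nodes: list of (op_type, output_name) pairs in topological order.
    stashed_activations: names of tensors whose lifetime spans fw and bw."""
    producers = {output: op_type for op_type, output in nodes}
    return [t for t in stashed_activations
            if producers.get(t) in RECOMPUTABLE_OPS]

nodes = [("MatMul", "x1"), ("Gelu", "x2"), ("MatMul", "x3")]
# x2 is produced by Gelu, so it can be dropped and recomputed in backward.
print(find_recompute_candidates(nodes, ["x1", "x2"]))
```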

Baseline

On a 1P model (DeBERTa V2), sequence length 256, training with 16 A100 GPUs. With the latest main branch, we can run batch size 16, and the maximum batch size is < 32, so 16 is usually chosen by data scientists. 65% of the 40 GB memory is used during training, with SamplesPerSec=479.2543353561354.


With this PR

With Gelu recomputed to lower the memory peak, batch size 32 can be run. 97% of the 40 GB A100 memory is used, and SamplesPerSec=562.041593991271 (1.17x of baseline).
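The 1.17x figure follows directly from the two throughput numbers reported above; a quick arithmetic check:

```python
# Throughput numbers from the baseline and recompute runs above.
baseline_sps = 479.2543353561354   # batch size 16, no recompute
recompute_sps = 562.041593991271   # batch size 32, Gelu recomputed

speedup = recompute_sps / baseline_sps
print(f"speedup: {speedup:.2f}x")  # ~1.17x
```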


Transformation Summary

[ RUN      ] MemoryAlleviationTests.GeluRecompute
2022-09-04 15:46:52.033469118 [W:onnxruntime:Default, memory_alleviation.cc:369 ApplyImpl] Node Gelu_1(Gelu) can be recomputed
2022-09-04 15:46:52.033524613 [W:onnxruntime:Default, memory_alleviation.cc:446 PrintSummary]
MemoryAlleviation Summary:
        Type config:
                Dropout-0, Gelu-1, Tile-0
        =================================
        Subgraph: Gelu +        Recompute
                Shape   Frequency
                input1_dim0 x 500 x     1
        --------------------------------
        =================================

"Type config" shows whether recompute is enabled by the user for each operator type: 0 = disabled, 1 = enabled. "Subgraph" is the kind of subgraph that will be recomputed; in this case it is a single "Gelu" node, marked "Recompute". "Shape" and "Frequency" mean that, for this recompute, one tensor of size (batch size, 500) will be saved, because it will be recomputed instead of stashed.
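The memory freed by one summary entry can be estimated from its shape and frequency. A minimal sketch, assuming fp16 activations (2 bytes per element) and a concrete batch size substituted for the symbolic dim `input1_dim0`; the helper name is illustrative, not part of ORT:

```python
# Estimate bytes freed by one recompute entry from the summary:
# shape "input1_dim0 x 500", frequency 1, here with batch size 32 and fp16.
def saved_bytes(shape, frequency, bytes_per_elem=2):
    n = 1
    for dim in shape:
        n *= dim
    return n * frequency * bytes_per_elem

print(saved_bytes((32, 500), 1))  # 32 * 500 * 1 * 2 = 32000 bytes
```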

Even if users did not enable any operator recompute, this summary is still printed, so we can tell whether recompute would be helpful for the model being run.

worker-0: 2022-09-06 04:51:54.202336019 [W:onnxruntime:, memory_alleviation.cc:457 PrintSummary]
worker-0: MemoryAlleviation Summary:
worker-0:       Type config:
worker-0:               Dropout-0, Gelu-1, Tile-1
worker-0:       =================================
worker-0:       Subgraph: BiasGelu +    Recompute
worker-0:               Shape   Frequency
worker-0:               attention_mask1_dim0 x 256 x 6144 x     48
worker-0:       --------------------------------
worker-0:       Subgraph: Where + BitmaskDropout +      Disabled
worker-0:               Shape   Frequency
worker-0:               attention_mask1_dim0 x 256 x 1536 x     2
worker-0:       --------------------------------
worker-0:       Subgraph: Where + BitmaskDropout + Reshape +    Disabled
worker-0:               Shape   Frequency
worker-0:               24*attention_mask1_dim0 x 256 x 256 x   48
worker-0:       --------------------------------
worker-0:       Subgraph: Tile +        Recompute
worker-0:               Shape   Frequency
worker-0:               24*attention_mask1_dim0 x 512 x 64 x    96
worker-0:       --------------------------------
worker-0:       Subgraph: BitmaskDropout +      Disabled
worker-0:               Shape   Frequency
worker-0:               attention_mask1_dim0 x 256 x 1536 x     2
worker-0:               attention_mask1_dim0 x 1536 x   2
worker-0:       --------------------------------
worker-0:       =================================


pengwa avatar Sep 04 '22 15:09 pengwa

Can you explain what config the user needs to set to enable this feature? From the code I gather user needs to set env variable ORTMODULE_ENABLE_MEMORY_ALLEVIATION to "Dropout:0,Gelu:1,Tile:0" is that it? Also how does the user understand the following:

  1. Such an optimization is available
  2. Which ones to enable and which ones to disable (Dropout\Gelu\Tile)
  3. Effect of enabling this optimization technique

askhade avatar Sep 16 '22 17:09 askhade

Can you explain what config the user needs to set to enable this feature? From the code I gather user needs to set env variable ORTMODULE_ENABLE_MEMORY_ALLEVIATION to "Dropout:0,Gelu:1,Tile:0" is that it? Also how does the user understand the following:

  1. Such an optimization is available: if a model does not use all GPU memory, or ORT can only reach a batch size like 13 or 31, we have the potential to enable recompute to double the batch size for better perf.
  2. Which ones to enable and which ones to disable (Dropout\Gelu\Tile): as you mentioned above, ORTMODULE_ENABLE_MEMORY_ALLEVIATION is currently the env var controlling operator-level recompute. As a first trial on a model, we can enable all of them; if the saving turns out to be more than we need, we can disable them one by one to avoid the added recompute latency.
  3. Effect of enabling this optimization technique: we hope to fully utilize the hardware, both compute and memory bandwidth, so a bigger batch size (a power of 2) is the most direct result of using this technique.
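Putting the answers above together, enabling the feature is just a matter of setting the env var before ORTModule wraps the model. The "Op:flag" format (1 = enable, 0 = disable) is taken from this discussion; the parsing helper below is purely illustrative and not part of ORT:

```python
import os

# Must be set before ORTModule exports the model, since the graph
# transformer reads it at that point. 1 = enable recompute for that op
# type, 0 = disable.
os.environ["ORTMODULE_ENABLE_MEMORY_ALLEVIATION"] = "Dropout:0,Gelu:1,Tile:1"

# Illustrative helper (not ORT code) showing how the string decodes:
def parse_alleviation_config(value):
    return {op: int(flag)
            for op, flag in (item.split(":") for item in value.split(","))}

print(parse_alleviation_config(
    os.environ["ORTMODULE_ENABLE_MEMORY_ALLEVIATION"]))
```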

As we talked offline, I put some answers inline. :)

pengwa avatar Sep 29 '22 12:09 pengwa

A couple of questions:

  1. export ORTMODULE_ENABLE_MEMORY_ALLEVIATION="Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1:-1,BiasGelu+:1:-1,BitmaskDropout+Cast+:1:-1,FusedMatMul+:1:-1,Cast+:1:-1,Mul+Add+:1:-1,Mul+Sub+:1:-1"

Here, what does the "+:1:-1" suffix on each entry mean?

Subgraph: FusedMatMul+ [1,0]: AlleviationType: Recompute [1,0]: Patterns: [1,0]: PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x Frequency:24

a. By "Frequency" do you mean "Occurrences"?
b. Why is this printed to stderr?

askhade avatar Nov 01 '22 04:11 askhade

Can you add a readme/developer doc here: https://github.com/microsoft/onnxruntime/tree/main/docs explaining how to enable and experiment with this feature? You can refine the PR description a little and add it as a doc along with this PR.

askhade avatar Nov 01 '22 04:11 askhade

Thanks a lot @askhade @baijumeswani for the efforts to review!!!

pengwa avatar Nov 03 '22 05:11 pengwa