lightning-thunder

OOM errors for Gemma-7, pythia-12b, Llama-2-13b-hf and Nous-Hermes-13b with FSDP zero3 and 2x8 H100

Open mpatel31415 opened this issue 1 year ago • 10 comments

🐛 Bug

Training Gemma-7b with FSDP zero3 on 2 nodes with 8 H100s each fails with an OOM error at micro batch size 2 for both thunder_cudnn and thunder_inductor_cat_cudnn. The same configuration works with inductor. Because Thunder has to fall back to a smaller batch size, its throughput is ~24% lower than Inductor's.

To Reproduce

Steps to reproduce the behavior (tested on a Slurm cluster; you can contact me on Slack for more details):

  1. Create a file script.sh (the same script with --compile inductor works):
#!/bin/bash
#SBATCH -A YOUR_DETAILS
#SBATCH -p batch
#SBATCH -J YOUR_DETAILS
#SBATCH -N 2
#SBATCH --ntasks-per-node 8
#SBATCH --time 0:29:00
#SBATCH --mail-type=FAIL
#SBATCH --exclusive


IMAGE="INTERNAL_IMAGE:pjnl-20240524"

TRAINING_CMD="python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
 --model_name Gemma-7b \
 --micro_batch_size 2 \
 --distributed_mode fsdp \
 --shard_mode zero3 \
 --compile thunder_cudnn 
"
  2. After logging in to the Slurm cluster, run:
sbatch script.sh
  3. The results are written to the slurm-JOB_ID.out file, which contains:

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 79.11 GiB of which 604.50 MiB is free. Including non-PyTorch memory, this process has 78.51 GiB memory in use. Of the allocated memory 72.09 GiB is allocated by PyTorch, and 3.52 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

There is also an error from NVFuser [edit -- moved to #514].
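The figures in the OOM message above can be cross-checked with a little arithmetic. The sketch below uses a hypothetical helper, oom_breakdown, that relies only on the wording of this particular message (other PyTorch versions may phrase it differently). Of the 78.51 GiB the process holds, 72.09 GiB is live PyTorch allocations, 3.52 GiB is reserved but unallocated, and roughly 2.9 GiB sits outside PyTorch's caching allocator. Since 3.52 GiB of reserved-but-free memory exceeds the 1 GiB request, the failure is consistent with the fragmentation hint the message itself gives.

```python
import re

# The OOM message from the job log (trimmed to the sentences with numbers).
MESSAGE = (
    "CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of "
    "79.11 GiB of which 604.50 MiB is free. Including non-PyTorch memory, this process "
    "has 78.51 GiB memory in use. Of the allocated memory 72.09 GiB is allocated by "
    "PyTorch, and 3.52 GiB is reserved by PyTorch but unallocated."
)


def oom_breakdown(message):
    """Extract the memory figures from a torch.OutOfMemoryError message.

    Assumes the exact wording quoted above; other PyTorch versions may differ.
    """

    def grab(pattern):
        match = re.search(pattern, message)
        if match is None:
            raise ValueError(f"pattern not found: {pattern}")
        return float(match.group(1))

    in_use = grab(r"this process has ([\d.]+) GiB memory in use")
    allocated = grab(r"([\d.]+) GiB is allocated by PyTorch")
    reserved_unallocated = grab(r"([\d.]+) GiB is reserved by PyTorch but unallocated")
    return {
        "requested_mib": grab(r"Tried to allocate ([\d.]+) MiB"),
        "free_mib": grab(r"of which ([\d.]+) MiB is free"),
        "total_gib": grab(r"total capacity of ([\d.]+) GiB"),
        "in_use_gib": in_use,
        "allocated_gib": allocated,
        "reserved_unallocated_gib": reserved_unallocated,
        # memory held by this process outside PyTorch's caching allocator
        "non_pytorch_gib": round(in_use - allocated - reserved_unallocated, 2),
    }


print(oom_breakdown(MESSAGE))
```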

Expected behavior

We should be able to train Gemma-7b with FSDP zero3 on 2 nodes with 8 H100s each at micro batch size 2 without an OOM error.

Environment

As in the Docker image used (INTERNAL_IMAGE:pjnl-20240524).

mpatel31415, May 29 '24

The same issue (a lower maximum batch size) is also present for the pythia-12b and Nous-Hermes-13b models on 1x8 H100 and 2x8 H100: Inductor can use micro batch size 4, but Thunder and Thunder Inductor only 2. For Llama-2-13b-hf, the maximum micro batch size is 1 for Thunder and 2 for Inductor.

The same issue possibly also affects:

  • falcon-7b (MBS 2 vs 4)
  • pythia-12b with shard_mode zero2 (other cases use zero3) on 1x8 H100 (MBS 1 vs 2)

mpatel31415, May 29 '24

triage review —

  • filed https://github.com/Lightning-AI/lightning-thunder/issues/514 for the nvFuser part of this issue (let's track that separately)
  • @kiya00 would you take a look at the Gemma-7 variant of this issue and see where the memory usage is higher than expected?

mruberry, Jun 03 '24

I am not sure whether that is what happened here, but I have seen nvFuser failures show up alongside OOM errors. It might not be an nvFuser issue at all: the OOM probably happened during fusion execution, and the nvFuser error gets printed as part of it. (Or it could really be an nvFuser error.)

parthmannan, Jun 03 '24

I reran the nvFuser failure with the standalone repro and did not see the error.

kevinstephano, Jun 03 '24

Yeah, the nvFuser output was likely just printed because the OOM happened during fusion execution. I have seen that before.

parthmannan, Jun 04 '24

Here is a comparison of memory usage for Llama-2-13b-hf on 1 node (8×H100) with zero3 vs. a single H100, with different numbers of layers:

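Memory used (GB). "Single H100" uses micro_batch_size=2, global_batch_size=2; "zero3" uses 8×H100 with micro_batch_size=2, global_batch_size=16.

| llama-2-13b-hf | single H100: thunder-cudnn | single H100: inductor | single H100: difference | zero3: thunder-cudnn | zero3: inductor | zero3: difference |
|---|---|---|---|---|---|---|
| 2 layers | 11.58 | 11.33 | 0.25 | 6.52 | 6.27 | 0.25 |
| 4 layers | 18.94 | 18.35 | 0.59 | 10.54 | 9.95 | 0.59 |
| 10 layers | 41.02 | 39.43 | 1.59 | 22.6 | 21 | 1.60 |
| 20 layers | 77.82 | 74.56 | 3.26 | 42.69 | 39.41 | 3.28 |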

Note that the two "difference" columns are almost identical, so we can expect that solving the memory-usage problem in the single-GPU case will also solve the zero3 problem.
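The "difference" columns can also be checked numerically. The sketch below copies the values from the table, recomputes Thunder's overhead over Inductor for each layer count, and fits a least-squares slope against the number of layers. Both settings come out at roughly 0.17 GB of extra memory per layer, i.e. the overhead grows with depth at essentially the same rate with and without zero3 sharding. The helper names are illustrative only.

```python
# Memory used (GB) for Llama-2-13b-hf, copied from the table above:
# n_layers -> (thunder-cudnn, inductor)
SINGLE_GPU = {2: (11.58, 11.33), 4: (18.94, 18.35), 10: (41.02, 39.43), 20: (77.82, 74.56)}
ZERO3 = {2: (6.52, 6.27), 4: (10.54, 9.95), 10: (22.6, 21.0), 20: (42.69, 39.41)}


def overhead(table):
    """Thunder's extra memory over Inductor for each layer count (GB)."""
    return {n: round(thunder - inductor, 2) for n, (thunder, inductor) in table.items()}


def per_layer_slope(diffs):
    """Least-squares slope of the overhead against the number of layers (GB/layer)."""
    xs, ys = list(diffs), list(diffs.values())
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var = sum((x - mx) ** 2 for x in xs)
    return cov / var


for name, table in (("single H100", SINGLE_GPU), ("zero3 8xH100", ZERO3)):
    d = overhead(table)
    print(name, d, f"~{per_layer_slope(d):.3f} GB/layer")
```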

Here is a further analysis of the memory usage on a single H100 with micro_batch_size=2 (memory snapshot image):

Note that we only focus on the parts that affect peak memory. The underlined files are the source files corresponding to the blocks in the memory snapshot (attached below).

The difference is not that big, and it comes from the different fusion regions formed by nvFuser and Triton, e.g. the first Triton kernel vs. the first nvFusion:

# Source Nodes: [add, mul, norm_x, rsqrt, x, x_1, x_normed, x_normed_1, x_normed_2], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_0.run(buf2, primals_20, primals_6, primals_1, buf0, buf3, 8192, 5120, grid=grid(8192), stream=stream0)
[t11, t19] = nvFusion0(t15, t4)
    # t5 = prims.convert_element_type(t4, dtypes.float32)  # t5: "cuda:0 f32[2, 4096, 5120]"
    # t6 = prims.mul(t5, t5)  # t6: "cuda:0 f32[2, 4096, 5120]"
    # t7 = prims.sum(t6, (2,))  # t7: "cuda:0 f32[2, 4096]"
    # t8 = prims.broadcast_in_dim(t7, [2, 4096, 1], [0, 1])  # t8: "cuda:0 f32[2, 4096, 1]"
    # t9 = prims.div(t8, 5120.0)  # t9: "cuda:0 f32[2, 4096, 1]"
    # t10 = prims.add(t9, 1e-05)  # t10: "cuda:0 f32[2, 4096, 1]"
    # t11 = prims.rsqrt(t10)  # t11: "cuda:0 f32[2, 4096, 1]"
    # t12 = prims.broadcast_in_dim(t11, (2, 4096, 5120), (0, 1, 2))  # t12: "cuda:0 f32[2, 4096, 5120]"
    # t13 = prims.mul(t5, t12)  # t13: "cuda:0 f32[2, 4096, 5120]"
    # t17 = prims.convert_element_type(t15, dtypes.float32)  # t17: "cuda:0 f32[2, 4096, 5120]"
    # t18 = prims.mul(t13, t17)  # t18: "cuda:0 f32[2, 4096, 5120]"
    # t19 = prims.convert_element_type(t18, dtypes.bfloat16)  # t19: "cuda:0 bf16[2, 4096, 5120]"

Attached files: inductor_cforlsiyjtgqyubdy4jnnxyzxm7kz4kmrusqj5z6xvlmqhyj7rti.log, trace.log, trace_ori.log ("trace.log" is a de-commented version of "trace_ori.log").
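Both the first Triton kernel and nvFusion0 compute an RMSNorm (Triton additionally fuses the embedding lookup into it). As a reading aid, here is a minimal pure-Python sketch of what the nvFusion0 prims compute for a single row, with the trace variable names in comments. The bf16/fp32 casts are omitted, and treating t15 as the broadcast norm weight is an assumption based on how it is used. The add_unit_offset flag mirrors the (1 + weight) scaling visible in Gemma's final norm (t5339 in nvFusion84, further down this thread). rms_norm_row is a hypothetical helper, not a Thunder or litgpt API.

```python
import math


def rms_norm_row(x, weight, eps=1e-5, add_unit_offset=False):
    """RMSNorm over one row (the last dimension), following the prims in nvFusion0.

    The bf16 <-> fp32 casts (t5, t17, t19) are omitted: everything is a Python float.
    """
    # t6, t7, t9: mean of the squares over the last dimension
    mean_sq = sum(v * v for v in x) / len(x)
    # t10, t11: add eps, then take the reciprocal square root
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Gemma's norm scales by (1 + weight) instead of weight
    scale = [1.0 + w if add_unit_offset else w for w in weight]
    # t13: normalise; t18: multiply by the (presumably broadcast) weight
    return [v * inv_rms * s for v, s in zip(x, scale)]


out = rms_norm_row([3.0, 4.0], [1.0, 1.0], eps=0.0)
print(out)  # each element divided by sqrt((9 + 16) / 2)
```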

cc: @IvanYashchuk @kevinstephano

kiya00, Jun 11 '24

Quick update on this: I did some analysis of the models, and these should be the minimum memory requirements for the thunder_cudnn compile option with micro_batch_size=2:

| Model name | Total memory (GB) | Per-node memory (GB) |
|---|---|---|
| Gemma-7b | 156.42 | x |
| pythia-12b | 126.56 | 72.25 |
| Llama-2-13b-hf | 168.91 | x |
| Nous-Hermes-13b | 140.87 | x |

riccardofelluga, Jul 09 '24

By further reducing Llama-2-13b-hf to n_layers=1, I found that in this case the extra memory usage is related to rematerialization. The only difference is the memory allocated by [t93, t103, t111] = nvFusion2(t89, t2, t107): t111 occupies the same memory as buf16 allocated by triton_red_fused__to_copy_add_mean_mul_rsqrt_3, and t103 is small, so t93 (the result of an add) is the tensor of interest. For Triton, I think this add (of buf13 and buf0) is recomputed in both of the following two kernels:

# Source Nodes: [add_4, mul_7, mul_9, norm_x_1, rsqrt_1, x_2, x_3, x_normed_3, x_normed_4], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_3.run(buf15, buf13, buf0, primals_8, buf16, 8192, 5120, grid=grid(8192), stream=stream0)
...
# Source Nodes: [add_6, mul_11, norm_x_2, rsqrt_2, x_2, x_5, x_6, x_7, x_normed_5, x_normed_6], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_5.run(buf21, buf23, buf13, buf0, primals_12, buf24, 8192, 5120, grid=grid(8192), stream=stream0)

But in Thunder, t93 is passed from nvFusion2 to nvFusion4:

...
  [t93, t103, t111] = nvFusion2(t89, t2, t107)
    # t90 = prims.convert_element_type(t89, dtypes.float32)  # t90: "cuda:0 f32[2, 4096, 5120]"
    # t91 = prims.convert_element_type(t2, dtypes.float32)  # t91: "cuda:0 f32[2, 4096, 5120]"
    # t92 = prims.add(t90, t91)  # t92: "cuda:0 f32[2, 4096, 5120]"
    # t93 = prims.convert_element_type(t92, dtypes.bfloat16)  # t93: "cuda:0 bf16[2, 4096, 5120]"
    # t95 = prims.mul(t92, t92)  # t95: "cuda:0 f32[2, 4096, 5120]"
    # t97 = prims.sum(t95, (2,))  # t97: "cuda:0 f32[2, 4096]"
    # t98 = prims.broadcast_in_dim(t97, [2, 4096, 1], [0, 1])  # t98: "cuda:0 f32[2, 4096, 1]"
    # t100 = prims.div(t98, 5120.0)  # t100: "cuda:0 f32[2, 4096, 1]"
    # t102 = prims.add(t100, 1e-05)  # t102: "cuda:0 f32[2, 4096, 1]"
    # t103 = prims.rsqrt(t102)  # t103: "cuda:0 f32[2, 4096, 1]"
    # t104 = prims.broadcast_in_dim(t103, (2, 4096, 5120), (0, 1, 2))  # t104: "cuda:0 f32[2, 4096, 5120]"
    # t105 = prims.mul(t92, t104)  # t105: "cuda:0 f32[2, 4096, 5120]"
    # t109 = prims.convert_element_type(t107, dtypes.float32)  # t109: "cuda:0 f32[2, 4096, 5120]"
    # t110 = prims.mul(t105, t109)  # t110: "cuda:0 f32[2, 4096, 5120]"
    # t111 = prims.convert_element_type(t110, dtypes.bfloat16)  # t111: "cuda:0 bf16[2, 4096, 5120]"
  t112 = torch.nn.functional.linear(t111, t_transformer_h_0_mlp_fc_1_weight, None)  # t112: "cuda:0 bf16[2, 4096, 13824]"
    # t112 = ltorch.linear(t111, t_transformer_h_0_mlp_fc_1_weight, None)  # t112: "cuda:0 bf16[2, 4096, 13824]"
      # t112 = prims.linear(t111, t_transformer_h_0_mlp_fc_1_weight, None)  # t112: "cuda:0 bf16[2, 4096, 13824]"
  t113 = torch.nn.functional.linear(t111, t_transformer_h_0_mlp_fc_2_weight, None)  # t113: "cuda:0 bf16[2, 4096, 13824]"
    # t113 = ltorch.linear(t111, t_transformer_h_0_mlp_fc_2_weight, None)  # t113: "cuda:0 bf16[2, 4096, 13824]"
      # t113 = prims.linear(t111, t_transformer_h_0_mlp_fc_2_weight, None)  # t113: "cuda:0 bf16[2, 4096, 13824]"
  [t127] = nvFusion3(t112, t113)
    # t114 = prims.convert_element_type(t112, dtypes.float32)  # t114: "cuda:0 f32[2, 4096, 13824]"
    # t115 = prims.neg(t114)  # t115: "cuda:0 f32[2, 4096, 13824]"
    # t116 = prims.exp(t115)  # t116: "cuda:0 f32[2, 4096, 13824]"
    # t117 = prims.add(1.0, t116)  # t117: "cuda:0 f32[2, 4096, 13824]"
    # t118 = prims.reciprocal(t117)  # t118: "cuda:0 f32[2, 4096, 13824]"
    # t122 = prims.mul(t114, t118)  # t122: "cuda:0 f32[2, 4096, 13824]"
    # t125 = prims.convert_element_type(t113, dtypes.float32)  # t125: "cuda:0 f32[2, 4096, 13824]"
    # t126 = prims.mul(t122, t125)  # t126: "cuda:0 f32[2, 4096, 13824]"
    # t127 = prims.convert_element_type(t126, dtypes.bfloat16)  # t127: "cuda:0 bf16[2, 4096, 13824]"
  t128 = torch.nn.functional.linear(t127, t_transformer_h_0_mlp_proj_weight, None)  # t128: "cuda:0 bf16[2, 4096, 5120]"
    # t128 = ltorch.linear(t127, t_transformer_h_0_mlp_proj_weight, None)  # t128: "cuda:0 bf16[2, 4096, 5120]"
      # t128 = prims.linear(t127, t_transformer_h_0_mlp_proj_weight, None)  # t128: "cuda:0 bf16[2, 4096, 5120]"
  t189 = torch.unsqueeze(t_transformer_ln_f_weight, 0)  # t189: "cuda:0 bf16[1, 5120]"
    # t189 = ltorch.unsqueeze(t_transformer_ln_f_weight, 0)  # t189: "cuda:0 bf16[1, 5120]"
      # t189 = prims.broadcast_in_dim(t_transformer_ln_f_weight, [1, 5120], [1])  # t189: "cuda:0 bf16[1, 5120]"
  t190 = torch.unsqueeze(t189, 1)  # t190: "cuda:0 bf16[1, 1, 5120]"
    # t190 = ltorch.unsqueeze(t189, 1)  # t190: "cuda:0 bf16[1, 1, 5120]"
      # t190 = prims.broadcast_in_dim(t189, [1, 1, 5120], [0, 2])  # t190: "cuda:0 bf16[1, 1, 5120]"
  del t189
  t146 = Tensor.expand(t190, (2, 4096, 5120))  # t146: "cuda:0 bf16[2, 4096, 5120]"
    # t146 = ltorch.expand(t190, (2, 4096, 5120))  # t146: "cuda:0 bf16[2, 4096, 5120]"
      # t146 = prims.broadcast_in_dim(t190, (2, 4096, 5120), (0, 1, 2))  # t146: "cuda:0 bf16[2, 4096, 5120]"
  del t190
  [t142, t150] = nvFusion4(t93, t128, t146)
    # t130 = prims.convert_element_type(t93, dtypes.float32)  # t130: "cuda:0 f32[2, 4096, 5120]"
    # t129 = prims.convert_element_type(t128, dtypes.float32)  # t129: "cuda:0 f32[2, 4096, 5120]"
    # t131 = prims.add(t129, t130)  # t131: "cuda:0 f32[2, 4096, 5120]"
    # t134 = prims.mul(t131, t131)  # t134: "cuda:0 f32[2, 4096, 5120]"
    # t136 = prims.sum(t134, (2,))  # t136: "cuda:0 f32[2, 4096]"
    # t137 = prims.broadcast_in_dim(t136, [2, 4096, 1], [0, 1])  # t137: "cuda:0 f32[2, 4096, 1]"
    # t139 = prims.div(t137, 5120.0)  # t139: "cuda:0 f32[2, 4096, 1]"
    # t141 = prims.add(t139, 1e-05)  # t141: "cuda:0 f32[2, 4096, 1]"
    # t142 = prims.rsqrt(t141)  # t142: "cuda:0 f32[2, 4096, 1]"
    # t143 = prims.broadcast_in_dim(t142, (2, 4096, 5120), (0, 1, 2))  # t143: "cuda:0 f32[2, 4096, 5120]"
    # t144 = prims.mul(t131, t143)  # t144: "cuda:0 f32[2, 4096, 5120]"
    # t148 = prims.convert_element_type(t146, dtypes.float32)  # t148: "cuda:0 f32[2, 4096, 5120]"
    # t149 = prims.mul(t144, t148)  # t149: "cuda:0 f32[2, 4096, 5120]"
    # t150 = prims.convert_element_type(t149, dtypes.bfloat16)  # t150: "cuda:0 bf16[2, 4096, 5120]"
...
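To put the cost of passing t93 in perspective, its size can be read off the type annotation in the trace. The sketch below uses a hypothetical helper, annotated_size_bytes, that parses annotations such as "cuda:0 bf16[2, 4096, 5120]"; only the bf16, f16 and f32 dtypes are handled (an assumption that covers the traces in this thread). For t93 the result is exactly 80 MiB, which Thunder keeps alive between nvFusion2 and nvFusion4 in every layer, while Triton recomputes the add instead.

```python
import math
import re

# Bytes per element for the dtype abbreviations used in the traces above.
DTYPE_BYTES = {"bf16": 2, "f16": 2, "f32": 4}


def annotated_size_bytes(annotation):
    """Size in bytes of a tensor from a trace annotation like 'cuda:0 bf16[2, 4096, 5120]'."""
    # optional device prefix, then dtype, then the bracketed shape
    m = re.fullmatch(r"\s*(?:\S+\s+)?(\w+)\[([\d,\s]*)\]\s*", annotation)
    if m is None:
        raise ValueError(f"unrecognised annotation: {annotation!r}")
    dtype, dims = m.groups()
    shape = [int(d) for d in dims.split(",") if d.strip()]
    return DTYPE_BYTES[dtype] * math.prod(shape)


t93 = annotated_size_bytes("cuda:0 bf16[2, 4096, 5120]")
print(t93, t93 / 2**20, "MiB")  # t93 per layer
```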

cc: @IvanYashchuk

kiya00, Aug 06 '24

Is this caused by non-optimal segmentation within nvFuser, or by nvFuser not supporting embedding (or another operation), so that the problem lies in how the region sent to nvFuser is segmented?

csarofeen, Aug 07 '24

It is hard to say whether nvFuser support for embedding (or another operation) would help save memory. According to my earlier analysis of the 2-layer case, the operators appearing in the Thunder trace and the Triton script are different (Thunder has its own primitives), and the regions they choose to fuse also differ, so the memory usage is hard to compare directly.

So I reduced the model to 1 layer with fewer operators and found that in this case the rematerialization decision causes the extra memory usage: Thunder decides to pass the tensor between two nvFusion regions instead of recomputing it. What's tricky is that the peak memory can shift with a different number of layers, so this is not guaranteed to be the cause of the OOM for the full model.

I'll try to confirm whether what causes the extra memory usage in the reduced model is actually the cause of the OOM for the full model.

kiya00, Aug 08 '24

Thank you, @kiya00, for investigating the problem! The min-cut algorithm preferred to send t93 between fusions because, with our capacities, that is an optimal solution. We can influence this choice by choosing different capacities; in this instance, the recomputation should start from the producer's output since it's just elementwise operations. We can get the desired effect by setting the weight here to 0.0 (instead of dividing it by 2.0): https://github.com/Lightning-AI/lightning-thunder/blob/a59b4efb54ea69980fa7fb63ea78923290e93d5a/thunder/core/rematerialization.py#L345. With the following patch, the number of saved tensors for Gemma-7 is the same for Thunder and torch.compile:

diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py
index 1aa01d52..dada034b 100644
--- a/thunder/core/rematerialization.py
+++ b/thunder/core/rematerialization.py
@@ -342,7 +342,9 @@ def find_cut(
     def add_edges(var):
         var_name = var.name
         weight = get_weight(var)
-        weight = weight / 2.0 if var_name in (x.name for x in producer.args) else weight
+        # If the variable is an input to the producer, it is free to be used by the
+        # consumer as well. So we need to set the weight to 0.0
+        weight = 0.0 if var_name in (x.name for x in producer.args) else weight
         add_edge(var_name + "_in", var_name + "_out", capacity=weight)
         for user in combined_consumers._dict.get(var_name, tuple()):
             if user.sym.id in sym_skip_list:

However, this change increases nvFuser's compilation time a lot, because the residual connections are now recomputed and the fusion size grows with the number of layers.

For Gemma-7B here's an example of the last fusion in the forward pass:
[t5333, t5345] = nvFusion84(t2, t98, t162, t285, t353, t476, t544, t667, t735, t858, t926, t1049, t1117, t1240, t1308, t1431, t1499, t1622, t1690, t1813, t1881, t2004, t2072, t2195, t2263, t2386, t2454, t2577, t2645, t2768, t2836, t2959, t3027, t3150, t3218, t3341, t3409, t3532, t3600, t3723, t3791, t3914, t3982, t4105, t4173, t4296, t4364, t4487, t4555, t4678, t4746, t4869, t4937, t5060, t5128, t5251, t5319, t_transformer_ln_f_weight)
    # t3 = prims.convert_element_type(t2, dtypes.float32)  # t3: "cuda:0 f32[1, 4096, 3072]"
    # t4 = prims.mul(t3, 55.42562584220407)  # t4: "cuda:0 f32[1, 4096, 3072]"
    # t99 = prims.convert_element_type(t98, dtypes.float32)  # t99: "cuda:0 f32[1, 4096, 3072]"
    # t101 = prims.add(t99, t4)  # t101: "cuda:0 f32[1, 4096, 3072]"
    # t163 = prims.convert_element_type(t162, dtypes.float32)  # t163: "cuda:0 f32[1, 4096, 3072]"
    # t165 = prims.add(t163, t101)  # t165: "cuda:0 f32[1, 4096, 3072]"
    # t286 = prims.convert_element_type(t285, dtypes.float32)  # t286: "cuda:0 f32[1, 4096, 3072]"
    # t288 = prims.add(t286, t165)  # t288: "cuda:0 f32[1, 4096, 3072]"
    # t354 = prims.convert_element_type(t353, dtypes.float32)  # t354: "cuda:0 f32[1, 4096, 3072]"
    # t356 = prims.add(t354, t288)  # t356: "cuda:0 f32[1, 4096, 3072]"
    # t477 = prims.convert_element_type(t476, dtypes.float32)  # t477: "cuda:0 f32[1, 4096, 3072]"
    # t479 = prims.add(t477, t356)  # t479: "cuda:0 f32[1, 4096, 3072]"
    # t545 = prims.convert_element_type(t544, dtypes.float32)  # t545: "cuda:0 f32[1, 4096, 3072]"
    # t547 = prims.add(t545, t479)  # t547: "cuda:0 f32[1, 4096, 3072]"
    # t668 = prims.convert_element_type(t667, dtypes.float32)  # t668: "cuda:0 f32[1, 4096, 3072]"
    # t670 = prims.add(t668, t547)  # t670: "cuda:0 f32[1, 4096, 3072]"
    # t736 = prims.convert_element_type(t735, dtypes.float32)  # t736: "cuda:0 f32[1, 4096, 3072]"
    # t738 = prims.add(t736, t670)  # t738: "cuda:0 f32[1, 4096, 3072]"
    # t859 = prims.convert_element_type(t858, dtypes.float32)  # t859: "cuda:0 f32[1, 4096, 3072]"
    # t861 = prims.add(t859, t738)  # t861: "cuda:0 f32[1, 4096, 3072]"
    # t927 = prims.convert_element_type(t926, dtypes.float32)  # t927: "cuda:0 f32[1, 4096, 3072]"
    # t929 = prims.add(t927, t861)  # t929: "cuda:0 f32[1, 4096, 3072]"
    # t1050 = prims.convert_element_type(t1049, dtypes.float32)  # t1050: "cuda:0 f32[1, 4096, 3072]"
    # t1052 = prims.add(t1050, t929)  # t1052: "cuda:0 f32[1, 4096, 3072]"
    # t1118 = prims.convert_element_type(t1117, dtypes.float32)  # t1118: "cuda:0 f32[1, 4096, 3072]"
    # t1120 = prims.add(t1118, t1052)  # t1120: "cuda:0 f32[1, 4096, 3072]"
    # t1241 = prims.convert_element_type(t1240, dtypes.float32)  # t1241: "cuda:0 f32[1, 4096, 3072]"
    # t1243 = prims.add(t1241, t1120)  # t1243: "cuda:0 f32[1, 4096, 3072]"
    # t1309 = prims.convert_element_type(t1308, dtypes.float32)  # t1309: "cuda:0 f32[1, 4096, 3072]"
    # t1311 = prims.add(t1309, t1243)  # t1311: "cuda:0 f32[1, 4096, 3072]"
    # t1432 = prims.convert_element_type(t1431, dtypes.float32)  # t1432: "cuda:0 f32[1, 4096, 3072]"
    # t1434 = prims.add(t1432, t1311)  # t1434: "cuda:0 f32[1, 4096, 3072]"
    # t1500 = prims.convert_element_type(t1499, dtypes.float32)  # t1500: "cuda:0 f32[1, 4096, 3072]"
    # t1502 = prims.add(t1500, t1434)  # t1502: "cuda:0 f32[1, 4096, 3072]"
    # t1623 = prims.convert_element_type(t1622, dtypes.float32)  # t1623: "cuda:0 f32[1, 4096, 3072]"
    # t1625 = prims.add(t1623, t1502)  # t1625: "cuda:0 f32[1, 4096, 3072]"
    # t1691 = prims.convert_element_type(t1690, dtypes.float32)  # t1691: "cuda:0 f32[1, 4096, 3072]"
    # t1693 = prims.add(t1691, t1625)  # t1693: "cuda:0 f32[1, 4096, 3072]"
    # t1814 = prims.convert_element_type(t1813, dtypes.float32)  # t1814: "cuda:0 f32[1, 4096, 3072]"
    # t1816 = prims.add(t1814, t1693)  # t1816: "cuda:0 f32[1, 4096, 3072]"
    # t1882 = prims.convert_element_type(t1881, dtypes.float32)  # t1882: "cuda:0 f32[1, 4096, 3072]"
    # t1884 = prims.add(t1882, t1816)  # t1884: "cuda:0 f32[1, 4096, 3072]"
    # t2005 = prims.convert_element_type(t2004, dtypes.float32)  # t2005: "cuda:0 f32[1, 4096, 3072]"
    # t2007 = prims.add(t2005, t1884)  # t2007: "cuda:0 f32[1, 4096, 3072]"
    # t2073 = prims.convert_element_type(t2072, dtypes.float32)  # t2073: "cuda:0 f32[1, 4096, 3072]"
    # t2075 = prims.add(t2073, t2007)  # t2075: "cuda:0 f32[1, 4096, 3072]"
    # t2196 = prims.convert_element_type(t2195, dtypes.float32)  # t2196: "cuda:0 f32[1, 4096, 3072]"
    # t2198 = prims.add(t2196, t2075)  # t2198: "cuda:0 f32[1, 4096, 3072]"
    # t2264 = prims.convert_element_type(t2263, dtypes.float32)  # t2264: "cuda:0 f32[1, 4096, 3072]"
    # t2266 = prims.add(t2264, t2198)  # t2266: "cuda:0 f32[1, 4096, 3072]"
    # t2387 = prims.convert_element_type(t2386, dtypes.float32)  # t2387: "cuda:0 f32[1, 4096, 3072]"
    # t2389 = prims.add(t2387, t2266)  # t2389: "cuda:0 f32[1, 4096, 3072]"
    # t2455 = prims.convert_element_type(t2454, dtypes.float32)  # t2455: "cuda:0 f32[1, 4096, 3072]"
    # t2457 = prims.add(t2455, t2389)  # t2457: "cuda:0 f32[1, 4096, 3072]"
    # t2578 = prims.convert_element_type(t2577, dtypes.float32)  # t2578: "cuda:0 f32[1, 4096, 3072]"
    # t2580 = prims.add(t2578, t2457)  # t2580: "cuda:0 f32[1, 4096, 3072]"
    # t2646 = prims.convert_element_type(t2645, dtypes.float32)  # t2646: "cuda:0 f32[1, 4096, 3072]"
    # t2648 = prims.add(t2646, t2580)  # t2648: "cuda:0 f32[1, 4096, 3072]"
    # t2769 = prims.convert_element_type(t2768, dtypes.float32)  # t2769: "cuda:0 f32[1, 4096, 3072]"
    # t2771 = prims.add(t2769, t2648)  # t2771: "cuda:0 f32[1, 4096, 3072]"
    # t2837 = prims.convert_element_type(t2836, dtypes.float32)  # t2837: "cuda:0 f32[1, 4096, 3072]"
    # t2839 = prims.add(t2837, t2771)  # t2839: "cuda:0 f32[1, 4096, 3072]"
    # t2960 = prims.convert_element_type(t2959, dtypes.float32)  # t2960: "cuda:0 f32[1, 4096, 3072]"
    # t2962 = prims.add(t2960, t2839)  # t2962: "cuda:0 f32[1, 4096, 3072]"
    # t3028 = prims.convert_element_type(t3027, dtypes.float32)  # t3028: "cuda:0 f32[1, 4096, 3072]"
    # t3030 = prims.add(t3028, t2962)  # t3030: "cuda:0 f32[1, 4096, 3072]"
    # t3151 = prims.convert_element_type(t3150, dtypes.float32)  # t3151: "cuda:0 f32[1, 4096, 3072]"
    # t3153 = prims.add(t3151, t3030)  # t3153: "cuda:0 f32[1, 4096, 3072]"
    # t3219 = prims.convert_element_type(t3218, dtypes.float32)  # t3219: "cuda:0 f32[1, 4096, 3072]"
    # t3221 = prims.add(t3219, t3153)  # t3221: "cuda:0 f32[1, 4096, 3072]"
    # t3342 = prims.convert_element_type(t3341, dtypes.float32)  # t3342: "cuda:0 f32[1, 4096, 3072]"
    # t3344 = prims.add(t3342, t3221)  # t3344: "cuda:0 f32[1, 4096, 3072]"
    # t3410 = prims.convert_element_type(t3409, dtypes.float32)  # t3410: "cuda:0 f32[1, 4096, 3072]"
    # t3412 = prims.add(t3410, t3344)  # t3412: "cuda:0 f32[1, 4096, 3072]"
    # t3533 = prims.convert_element_type(t3532, dtypes.float32)  # t3533: "cuda:0 f32[1, 4096, 3072]"
    # t3535 = prims.add(t3533, t3412)  # t3535: "cuda:0 f32[1, 4096, 3072]"
    # t3601 = prims.convert_element_type(t3600, dtypes.float32)  # t3601: "cuda:0 f32[1, 4096, 3072]"
    # t3603 = prims.add(t3601, t3535)  # t3603: "cuda:0 f32[1, 4096, 3072]"
    # t3724 = prims.convert_element_type(t3723, dtypes.float32)  # t3724: "cuda:0 f32[1, 4096, 3072]"
    # t3726 = prims.add(t3724, t3603)  # t3726: "cuda:0 f32[1, 4096, 3072]"
    # t3792 = prims.convert_element_type(t3791, dtypes.float32)  # t3792: "cuda:0 f32[1, 4096, 3072]"
    # t3794 = prims.add(t3792, t3726)  # t3794: "cuda:0 f32[1, 4096, 3072]"
    # t3915 = prims.convert_element_type(t3914, dtypes.float32)  # t3915: "cuda:0 f32[1, 4096, 3072]"
    # t3917 = prims.add(t3915, t3794)  # t3917: "cuda:0 f32[1, 4096, 3072]"
    # t3983 = prims.convert_element_type(t3982, dtypes.float32)  # t3983: "cuda:0 f32[1, 4096, 3072]"
    # t3985 = prims.add(t3983, t3917)  # t3985: "cuda:0 f32[1, 4096, 3072]"
    # t4106 = prims.convert_element_type(t4105, dtypes.float32)  # t4106: "cuda:0 f32[1, 4096, 3072]"
    # t4108 = prims.add(t4106, t3985)  # t4108: "cuda:0 f32[1, 4096, 3072]"
    # t4174 = prims.convert_element_type(t4173, dtypes.float32)  # t4174: "cuda:0 f32[1, 4096, 3072]"
    # t4176 = prims.add(t4174, t4108)  # t4176: "cuda:0 f32[1, 4096, 3072]"
    # t4297 = prims.convert_element_type(t4296, dtypes.float32)  # t4297: "cuda:0 f32[1, 4096, 3072]"
    # t4299 = prims.add(t4297, t4176)  # t4299: "cuda:0 f32[1, 4096, 3072]"
    # t4365 = prims.convert_element_type(t4364, dtypes.float32)  # t4365: "cuda:0 f32[1, 4096, 3072]"
    # t4367 = prims.add(t4365, t4299)  # t4367: "cuda:0 f32[1, 4096, 3072]"
    # t4488 = prims.convert_element_type(t4487, dtypes.float32)  # t4488: "cuda:0 f32[1, 4096, 3072]"
    # t4490 = prims.add(t4488, t4367)  # t4490: "cuda:0 f32[1, 4096, 3072]"
    # t4556 = prims.convert_element_type(t4555, dtypes.float32)  # t4556: "cuda:0 f32[1, 4096, 3072]"
    # t4558 = prims.add(t4556, t4490)  # t4558: "cuda:0 f32[1, 4096, 3072]"
    # t4679 = prims.convert_element_type(t4678, dtypes.float32)  # t4679: "cuda:0 f32[1, 4096, 3072]"
    # t4681 = prims.add(t4679, t4558)  # t4681: "cuda:0 f32[1, 4096, 3072]"
    # t4747 = prims.convert_element_type(t4746, dtypes.float32)  # t4747: "cuda:0 f32[1, 4096, 3072]"
    # t4749 = prims.add(t4747, t4681)  # t4749: "cuda:0 f32[1, 4096, 3072]"
    # t4870 = prims.convert_element_type(t4869, dtypes.float32)  # t4870: "cuda:0 f32[1, 4096, 3072]"
    # t4872 = prims.add(t4870, t4749)  # t4872: "cuda:0 f32[1, 4096, 3072]"
    # t4938 = prims.convert_element_type(t4937, dtypes.float32)  # t4938: "cuda:0 f32[1, 4096, 3072]"
    # t4940 = prims.add(t4938, t4872)  # t4940: "cuda:0 f32[1, 4096, 3072]"
    # t5061 = prims.convert_element_type(t5060, dtypes.float32)  # t5061: "cuda:0 f32[1, 4096, 3072]"
    # t5063 = prims.add(t5061, t4940)  # t5063: "cuda:0 f32[1, 4096, 3072]"
    # t5129 = prims.convert_element_type(t5128, dtypes.float32)  # t5129: "cuda:0 f32[1, 4096, 3072]"
    # t5131 = prims.add(t5129, t5063)  # t5131: "cuda:0 f32[1, 4096, 3072]"
    # t5252 = prims.convert_element_type(t5251, dtypes.float32)  # t5252: "cuda:0 f32[1, 4096, 3072]"
    # t5254 = prims.add(t5252, t5131)  # t5254: "cuda:0 f32[1, 4096, 3072]"
    # t5320 = prims.convert_element_type(t5319, dtypes.float32)  # t5320: "cuda:0 f32[1, 4096, 3072]"
    # t5322 = prims.add(t5320, t5254)  # t5322: "cuda:0 f32[1, 4096, 3072]"
    # t5325 = prims.mul(t5322, t5322)  # t5325: "cuda:0 f32[1, 4096, 3072]"
    # t5327 = prims.sum(t5325, (2,))  # t5327: "cuda:0 f32[1, 4096]"
    # t5328 = prims.broadcast_in_dim(t5327, [1, 4096, 1], [0, 1])  # t5328: "cuda:0 f32[1, 4096, 1]"
    # t5330 = prims.div(t5328, 3072.0)  # t5330: "cuda:0 f32[1, 4096, 1]"
    # t5332 = prims.add(t5330, 1e-05)  # t5332: "cuda:0 f32[1, 4096, 1]"
    # t5333 = prims.rsqrt(t5332)  # t5333: "cuda:0 f32[1, 4096, 1]"
    # t5334 = prims.broadcast_in_dim(t5333, (1, 4096, 3072), (0, 1, 2))  # t5334: "cuda:0 f32[1, 4096, 3072]"
    # t5335 = prims.mul(t5322, t5334)  # t5335: "cuda:0 f32[1, 4096, 3072]"
    # t5337 = prims.convert_element_type(t_transformer_ln_f_weight, dtypes.float32)  # t5337: "cuda:0 f32[3072]"
    # t5339 = prims.add(1.0, t5337)  # t5339: "cuda:0 f32[3072]"
    # t5340 = prims.convert_element_type(t5339, dtypes.bfloat16)  # t5340: "cuda:0 bf16[3072]"
    # t5341 = prims.broadcast_in_dim(t5340, (1, 4096, 3072), (2,))  # t5341: "cuda:0 bf16[1, 4096, 3072]"
    # t5343 = prims.convert_element_type(t5341, dtypes.float32)  # t5343: "cuda:0 f32[1, 4096, 3072]"
    # t5344 = prims.mul(t5335, t5343)  # t5344: "cuda:0 f32[1, 4096, 3072]"
    # t5345 = prims.convert_element_type(t5344, dtypes.bfloat16)  # t5345: "cuda:0 bf16[1, 4096, 3072]"
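Why the fusion grows with depth can be sketched from the listing above. Reading the trace, the residual stream after the last layer is the embedding output scaled by sqrt(3072) (t4 = t3 * 55.4256...), plus two branch outputs per layer (presumably the attention and MLP outputs; this is an interpretation of the trace). Recomputing it from scratch therefore needs every branch output as a fusion input. For Gemma-7b's 28 layers that is 1 + 2 * 28 + 1 = 58 inputs, matching the argument list of nvFusion84 (t2, 56 branch tensors, and t_transformer_ln_f_weight). residual_stream and fusion_inputs are hypothetical helpers for illustration.

```python
import math


def residual_stream(embedding, branch_outputs, n_embd):
    """Scalar sketch of the chain in nvFusion84: scale the embedding output by
    sqrt(n_embd) (t4 = t3 * 55.4256...), then add every per-layer branch output."""
    x = embedding * math.sqrt(n_embd)
    for branch in branch_outputs:
        x = x + branch
    return x


def fusion_inputs(n_layers, branches_per_layer=2):
    """Inputs needed to recompute the final residual from scratch: the embedding
    output, every branch output, and the final norm weight."""
    return 1 + branches_per_layer * n_layers + 1


print(fusion_inputs(28))  # 58, the number of arguments of nvFusion84
print(residual_stream(1.0, [1.0] * 56, 3072))
```

Scaling the producer-input capacity by a small factor instead of zeroing it keeps the min cut from reaching back through this entire chain.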

Instead of setting the weight (capacity) to 0.0, it's better to scale it by 0.1 or some other small value, which limits the rematerialization pass's ability to recompute the whole residual history. The following patch gives memory savings without blowing up the fusion size and compilation time:

diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py
index 1aa01d5..ab2fd6e 100644
--- a/thunder/core/rematerialization.py
+++ b/thunder/core/rematerialization.py
@@ -342,7 +342,10 @@ def find_cut(
     def add_edges(var):
         var_name = var.name
         weight = get_weight(var)
-        weight = weight / 2.0 if var_name in (x.name for x in producer.args) else weight
+        # If the variable is an input to the producer, it is almost free to be used by the
+        # consumer as well. So we scale its weight by 0.1.
+        # Setting the weight to 0.0 hurts in the long chain of consumers
+        weight = weight * 0.1 if var_name in (x.name for x in producer.args) else weight
         add_edge(var_name + "_in", var_name + "_out", capacity=weight)
         for user in combined_consumers._dict.get(var_name, tuple()):
             if user.sym.id in sym_skip_list:
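The effect of the scaling factor can be illustrated with a toy min cut. This is a hypothetical, simplified model, not Thunder's find_cut: each tensor becomes a name_in -> name_out edge whose capacity is its size (scaled by the factor for producer inputs, as in the patches above), the source feeds the producer inputs, and the consumer's required tensors feed the sink. Tensors whose edge is cut are passed between fusions; anything downstream of the cut is recomputed. The stand-in for t93 = add(t89, t2) uses hypothetical inputs a_bf16 (80 MiB) and b_f32 (160 MiB). Making one input fp32 is an assumption chosen so that 0.5 and 0.1 do not tie exactly; the max flow is a plain Edmonds-Karp.

```python
from collections import deque

INF = float("inf")


def max_flow_source_side(cap, source, sink):
    """Edmonds-Karp max flow. Returns (flow value, nodes reachable from source
    in the final residual graph), i.e. the source side of a minimum cut."""
    res = {}
    for u, nbrs in cap.items():
        for v, c in nbrs.items():
            res.setdefault(u, {})
            res.setdefault(v, {})
            res[u][v] = res[u].get(v, 0.0) + c
            res[v].setdefault(u, 0.0)  # reverse edge of the residual graph
    flow = 0.0
    while True:
        parent = {source: None}  # BFS for a shortest augmenting path
        queue = deque([source])
        while queue and sink not in parent:
            u = queue.popleft()
            for v, c in res[u].items():
                if c > 0 and v not in parent:
                    parent[v] = u
                    queue.append(v)
        if sink not in parent:
            return flow, set(parent)
        path, v = [], sink
        while parent[v] is not None:
            path.append((parent[v], v))
            v = parent[v]
        bottleneck = min(res[u][v] for u, v in path)
        for u, v in path:  # push the bottleneck capacity along the path
            res[u][v] -= bottleneck
            res[v][u] += bottleneck
        flow += bottleneck


def passed_tensors(factor, sizes, producer_args, deps, needed):
    """Tensors the toy min cut passes between fusions (everything else is recomputed)."""
    cap = {}

    def add_edge(u, v, c):
        cap.setdefault(u, {})[v] = c

    for name, size in sizes.items():
        weight = size * factor if name in producer_args else size
        add_edge(name + "_in", name + "_out", weight)
    for name in producer_args:
        add_edge("source", name + "_in", INF)  # producer inputs are available
    for out, inputs in deps.items():
        for name in inputs:
            add_edge(name + "_out", out + "_in", INF)  # out can be recomputed
    for name in needed:
        add_edge(name + "_out", "sink", INF)  # the consumer needs these values
    _, side = max_flow_source_side(cap, "source", "sink")
    return sorted(n for n in sizes if n + "_in" in side and n + "_out" not in side)


SIZES = {"a_bf16": 80.0, "b_f32": 160.0, "t93": 80.0}  # MiB, hypothetical
ARGS = ("a_bf16", "b_f32")
DEPS = {"t93": ("a_bf16", "b_f32")}

for factor in (0.5, 0.1, 0.05, 0.0):
    print(factor, passed_tensors(factor, SIZES, ARGS, DEPS, ("t93",)))
```

With 0.5 the cut saves t93 (cost 80 vs. 40 + 80); with 0.1 or smaller it passes the cheap producer inputs and recomputes t93. In the real graph a factor of 0.0 makes every upstream residual input free, which is what pulls the whole residual chain into one fusion.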

IvanYashchuk, Aug 22 '24

I noticed that a similar problem is present for the falcon-40b, Platypus-30B and vicuna-33b-v1.3 models.

mpatel31415, Nov 12 '24

@kiya00, could you please take over this issue and implement the rematerialization tweak (the second patch) described in https://github.com/Lightning-AI/lightning-thunder/issues/474#issuecomment-2305027560? Please check that it gives memory savings for Gemma-7B and causes no regression for other models. 0.1 seems like a good scaling factor, but it's arbitrary. If you have time, it's a good idea to study the effect of choosing a scaling factor between 0.05 (it can't be zero, because then the pass chooses to rematerialize too much) and 0.5 (the current value).

IvanYashchuk, Nov 19 '24

Synced with @IvanYashchuk: according to the updated benchmark results from 2024-11-07, we'll focus on solving the OOM for Platypus-30B, falcon-40b and vicuna-33b-v1.3 with the Thunder and ThunderFX backends by choosing a proper scaling factor.

kiya00, Nov 20 '24

Had a quick check on 1 node (8×H100) with factor=0.05:

torchrun --nproc_per_node=8 --nnodes=1 /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py --model_name Platypus-30B --micro_batch_size 1 --distributed_mode fsdp --shard_mode zero3 --compile thunder --max_iters 20 --warmup_iters 5

Platypus-30B still went OOM, and falcon-40b and vicuna-33b-v1.3 with the same parameters also went OOM. cc: @IvanYashchuk

kiya00, Nov 20 '24

Alright, something else is going on there then. Are these models just bigger or do they have other config differences from the ones that work?

IvanYashchuk, Nov 21 '24

Platypus-30B: Number of parameters: 4.07B
Gemma-7b: Number of parameters: 1.17B

So I think Platypus-30B is bigger. With n_layers=20:

  • factor=0.05 => Saved for backward size: 11596.01 MiB, Saved for backward number of tensors: 450, Memory used: 34.13 GB
  • factor=0.5 => Saved for backward size: 12558.01 MiB, Saved for backward number of tensors: 487, Memory used: 34.13 GB

It seems the factor doesn't affect the peak memory.

kiya00, Nov 26 '24

The initial problem of this issue has been solved: Gemma-7, pythia-12b, Llama-2-13b-hf and Nous-Hermes-13b now run without OOM with the Thunder and ThunderFX backends on 2×8 H100. The remaining problem is the OOM for Platypus-30B, falcon-40b and vicuna-33b-v1.3 with the Thunder and ThunderFX backends, which is tracked in https://github.com/Lightning-AI/lightning-thunder/issues/1233, so I'll close this issue.

kiya00, Nov 28 '24