Optimize zero3 fetch params using all_reduce

Open deepcharm opened this issue 1 year ago • 9 comments

  • Use all_reduce instead of all_gather to fetch module parameters. This improves performance by reducing the overhead of concatenation and slicing, which are no longer required.
  • Instead, all tensor views are created before the collective (all_reduce), so once it completes only the parameter status needs to be updated.
  • The behavior is enabled via a new boolean flag under the section "zero_optimization": { "stage3_use_all_reduce_for_fetch_params": true }
  • By default the optimization is not enabled.

deepcharm, Apr 16 '24 14:04
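A minimal single-process sketch of the mechanism described above, as we read it: NumPy stands in for device tensors, and an element-wise sum over per-rank buffers stands in for `torch.distributed.all_reduce` with SUM. The flat buffer is laid out parameter by parameter; each rank zero-fills it and writes only its own shard of every parameter at the offset that shard occupies in the full parameter; the per-parameter views are created before the collective. After the sum, the views already hold the full parameters, with no narrow/cat afterwards. `build_layout` and `simulated_all_reduce_fetch` are hypothetical helpers, not the PR's code.

```python
import numpy as np


def build_layout(numels, world_size):
    """Parameter-major layout: each param gets a padded slot of world_size equal chunks."""
    layout, offset = [], 0
    for n in numels:
        chunk = -(-n // world_size)  # ceil(n / world_size) elements per rank
        layout.append((offset, n, chunk))
        offset += chunk * world_size
    return layout, offset  # offset == total flat-buffer length


def simulated_all_reduce_fetch(full_params, world_size):
    numels = [p.size for p in full_params]
    layout, total = build_layout(numels, world_size)

    # Output buffer and per-param views are created BEFORE the (simulated) collective.
    flat = np.empty(total, dtype=np.float32)
    views = [flat[off:off + n] for off, n, _ in layout]

    # Each rank contributes a zero-filled buffer holding only its own shards,
    # written at the position they occupy in the full parameter.
    rank_buffers = []
    for rank in range(world_size):
        buf = np.zeros(total, dtype=np.float32)
        for p, (off, n, chunk) in zip(full_params, layout):
            start = rank * chunk
            stop = min(start + chunk, n)  # the last chunk may be partly padding
            if start < stop:
                buf[off + start:off + stop] = p[start:stop]
        rank_buffers.append(buf)

    # Stand-in for all_reduce(SUM): disjoint shards plus zeros add up to the full params.
    # Writing into `flat` in place keeps the pre-created views valid.
    np.sum(rank_buffers, axis=0, out=flat)
    return views


params = [np.arange(5, dtype=np.float32), np.arange(3, dtype=np.float32) * 10]
for full, fetched in zip(params, simulated_all_reduce_fetch(params, world_size=2)):
    assert np.array_equal(full, fetched)  # no cat/narrow needed after the collective
```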

@deepcharm, thanks for this interesting approach. Can you share some observed performance gains?

tjruwase, Apr 16 '24 14:04

@tjruwase We have observed around 9% performance gain on HPU in BERT workloads.

deepcharm, Apr 16 '24 15:04

Hi @deepcharm

Thx for the PR. Just curious: why would allreduce be faster than allgather? Allreduce is basically reduce-scatter + all-gather. Could we just use a coalesced allgather to remove the overhead of concatenation and slicing?

GuanhuaWang, Apr 16 '24 17:04
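To put numbers on the reduce-scatter + all-gather point: assuming textbook ring algorithms (collective libraries such as NCCL may pick tree or hierarchical algorithms instead), each rank sends about (W-1)/W · N elements for an all-gather producing N elements, and about 2(W-1)/W · N for an all-reduce of N elements, i.e. roughly twice the traffic. The hypothetical helper below only does that arithmetic.

```python
def ring_bytes_sent_per_rank(numel, world_size, bytes_per_elem, op):
    """Bytes one rank sends in a textbook ring collective with a `numel`-element result."""
    if world_size == 1:
        return 0.0
    chunk_bytes = numel / world_size * bytes_per_elem  # one chunk moves per ring step
    steps = {
        "all_gather": world_size - 1,        # W-1 forwarding steps
        "all_reduce": 2 * (world_size - 1),  # reduce-scatter (W-1) + all-gather (W-1)
    }[op]
    return steps * chunk_bytes


numel = 1_000_000_000  # e.g. a 1B-parameter model fetched in bf16 (2 bytes/elem)
ag = ring_bytes_sent_per_rank(numel, 8, 2, "all_gather")
ar = ring_bytes_sent_per_rank(numel, 8, 2, "all_reduce")
print(f"all_gather: {ag / 1e9:.2f} GB, all_reduce: {ar / 1e9:.2f} GB")
# all_gather: 1.75 GB, all_reduce: 3.50 GB
```

Any net gain from the all_reduce variant therefore has to come from somewhere other than bytes on the wire.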

Hi @GuanhuaWang, you're right: the proposed approach does add some communication overhead. The main idea is to rearrange the layout of the sharded pieces in the flat buffer to get an overall performance boost.

Hopefully, the attached slides below help clarify the benefits (less host-side overhead, smaller memory peak, etc.). Please let me know if that answers your questions.

1) Current Approach

2) Proposed Optimization

3) Comparison

deepcharm, Apr 18 '24 16:04
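A simplified model of the layout argument (not DeepSpeed's actual bookkeeping): with `all_gather`, each rank contributes its shard of every parameter back to back, so the gathered buffer is rank-major and the pieces of one parameter sit one rank-shard apart. Rebuilding P parameters from it takes about P·W slices plus P concatenations on the host, whereas a parameter-major buffer only needs P views, which can be created once before the collective. The hypothetical helpers below compute where the pieces land and count those operations, assuming each parameter is padded to a multiple of the world size W.

```python
def rank_major_pieces(numels, world_size):
    """Where each param's pieces land in an all_gather output (rank-major layout).

    Returns, per param, a list of (start, length) pairs, one per rank.
    """
    chunks = [-(-n // world_size) for n in numels]  # ceil(n / world_size) per param
    shard_len = sum(chunks)  # elements each rank contributes to the all_gather
    pieces, inner = [], 0
    for chunk in chunks:
        pieces.append([(r * shard_len + inner, chunk) for r in range(world_size)])
        inner += chunk
    return pieces


def post_gather_host_ops(numels, world_size):
    """Rough count of (narrow, cat) calls needed to rebuild params from that layout."""
    pieces = rank_major_pieces(numels, world_size)
    narrows = sum(len(p) for p in pieces)  # one slice per (param, rank)
    cats = len(pieces)  # one concatenation per param
    return narrows, cats


print(rank_major_pieces([5, 3], world_size=2))
# [[(0, 3), (5, 3)], [(3, 2), (8, 2)]]
print(post_gather_host_ops([1024] * 1000, world_size=8))
# (8000, 1000)
```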

Hi @deepcharm, these slides are cool and make sense to me. But 2) Proposed Optimization shows the unnecessary data concat & copy being removed by avoiding the params interleaving of allgather (not by allreduce). The allreduce is what confuses me: we don't do any sum/avg operation on the collected weights, right?

GuanhuaWang, Apr 23 '24 01:04
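On whether any sum/avg happens: in the proposed layout each element is owned by exactly one rank and every other rank contributes 0.0, so all_reduce with SUM only ever adds zeros to a single value. In IEEE-754 arithmetic x + 0.0 == x exactly for every x except -0.0, which comes back as +0.0 (numerically equal, different sign bit); NaN and inf pass through. An average would not work, because it divides by the world size. The pure-Python check below simulates one element's reduction in float64; `sum_with_zeros` is a hypothetical helper, and the same reasoning should carry over to the on-device dtype (e.g. bf16) assuming the hardware does not flush subnormals to zero.

```python
import math
import random
import struct


def bits(x):
    """Raw IEEE-754 double bits, so the comparison is exact (and -0.0 != +0.0)."""
    return struct.pack("<d", x)


def sum_with_zeros(value, owner, world_size):
    """Simulate all_reduce(SUM) of one element owned by `owner`; other ranks send 0.0."""
    contributions = [value if r == owner else 0.0 for r in range(world_size)]
    total = 0.0
    for c in contributions:  # order is irrelevant: at most one term is non-zero
        total += c
    return total


random.seed(0)
for _ in range(10_000):
    x = random.uniform(-1e6, 1e6) * 10 ** random.randint(-30, 30)
    assert bits(sum_with_zeros(x, owner=random.randrange(8), world_size=8)) == bits(x)

assert math.isnan(sum_with_zeros(float("nan"), 3, 8))
assert sum_with_zeros(float("inf"), 0, 8) == float("inf")
# The one exception: a -0.0 weight comes back as +0.0 (equal value, sign bit differs).
assert bits(sum_with_zeros(-0.0, 2, 8)) == bits(0.0)
```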

@deepcharm, I was not aware that narrow, cat, and copy operations on device tensors incurred high CPU overhead. I would like to learn more. Can you share the reason? How did you discover this? Can you share some repro/test code for this? Thanks!

tjruwase, Apr 23 '24 13:04

@tjruwase, we've seen this phenomenon in large models, where looping over many params causes significant CPU overhead. Possibly this issue is more specific to accelerators such as HPU. We will create a repro script and share it with you.

deepcharm, May 02 '24 11:05
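Pending the repro script, a rough host-side harness along these lines can show whether per-call overhead dominates when looping over many parameters. It is only a sketch under stated assumptions: NumPy on CPU stands in for device tensors, and slice assignment stands in for `narrow`/`copy_`, so it measures the fixed Python/dispatch cost of many small calls versus one bulk copy of the same data, not HPU behaviour. `time_small_vs_bulk` and its default sizes are hypothetical.

```python
import time

import numpy as np


def time_small_vs_bulk(num_params=2000, numel=256, repeats=5):
    """Best-of-`repeats` seconds: one copy per param vs a single copy of the same data."""
    src = np.random.rand(num_params * numel).astype(np.float32)
    dst = np.empty_like(src)
    best_loop = best_bulk = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        for i in range(num_params):  # one slice + copy per parameter
            s = slice(i * numel, (i + 1) * numel)
            dst[s] = src[s]
        best_loop = min(best_loop, time.perf_counter() - t0)

        t0 = time.perf_counter()
        dst[:] = src  # single bulk copy of the same bytes
        best_bulk = min(best_bulk, time.perf_counter() - t0)
    return best_loop, best_bulk


loop_s, bulk_s = time_small_vs_bulk()
print(f"per-param loop: {loop_s * 1e3:.3f} ms, bulk: {bulk_s * 1e3:.3f} ms")
```

Scaling `num_params` while keeping the total size fixed separates per-call cost from bandwidth; porting the loop to device tensors would be the natural next step for an actual repro.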

@deepcharm, very interesting, thanks for the explanation. I look forward to learning more from the repro script. I think it could make great documentation for performance debugging of zero3 on accelerators.

tjruwase, May 07 '24 13:05

Hi @tjruwase, for some reason the PR has been removed from the merge-queue. Can you please re-add it? Thanks

deepcharm, May 09 '24 12:05