
multi-machine allreduce

christopherhesse opened this issue on Jul 10, 2019 · 14 comments

Hi! I am looking to do fast multi-machine allreduce and broadcast operations when using JAX and MPI.

Here is a script that should be similar to my workload which I ran on 8 GCE instances with 8 V100 GPUs each and a 32 Gbit network:

https://gist.github.com/christopherhesse/b5d141a59d9648caab191d9ff6333117
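(In case the gist link rots: the structure is roughly the sketch below. This is a simplified stand-in rather than the actual script, assuming mpi4py and a host-side allreduce over numpy buffers.)

import time
import numpy as np
import jax
import jax.numpy as jnp
from mpi4py import MPI

comm = MPI.COMM_WORLD

def timed(fn):
    start = time.time()
    result = fn()
    return result, time.time() - start

num_params = 16_000_000
x = jax.device_put(jnp.ones(num_params, dtype=jnp.float32))

# "compute": some on-device work, forced to finish before the clock stops
grads, t_compute = timed(lambda: (2.0 * x).block_until_ready())
# "device_to_host": copy the result back into host memory
host, t_d2h = timed(lambda: np.asarray(grads))
# "allreduce": sum the host buffers across all MPI ranks
out = np.empty_like(host)
_, t_allreduce = timed(lambda: comm.Allreduce(host, out, op=MPI.SUM))

print("compute %.6f  device_to_host %.6f  allreduce %.6f"
      % (t_compute, t_d2h, t_allreduce))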

I ran it using MPICH:

mpiexec -f <hosts file> python <path to script>

The output looks like this:

num_params 1
compute             : min_elapsed 0.000424  avg_elapsed 0.026451  max_elapsed 0.259608
device_to_host      : min_elapsed 0.000070  avg_elapsed 0.000106  max_elapsed 0.000298
allreduce           : min_elapsed 0.000209  avg_elapsed 0.002230  max_elapsed 0.018252
num_params 16000000
compute             : min_elapsed 0.006838  avg_elapsed 0.023782  max_elapsed 0.155499
device_to_host      : min_elapsed 0.123953  avg_elapsed 0.135843  max_elapsed 0.163817
allreduce           : min_elapsed 0.505218  avg_elapsed 0.592024  max_elapsed 0.640469

So about 600 ms per allreduce for 16M float32s.

If I use nccl-tests with MPI support (make MPI=1):

mpiexec -f <hosts file> ./nccl-tests/build/all_reduce_perf -b 1M -e 64M -f 2 -g 1 -c 0

The output looks like this:

[0] #                                                     out-of-place                       in-place          
[0] #       size         count    type   redop     time   algbw   busbw  error     time   algbw   busbw  error
[0] #        (B)    (elements)                     (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
[0]      1048576        262144   float     sum    5328.1    0.20    0.39    N/A    3856.1    0.27    0.54    N/A
[0]      2097152        524288   float     sum    6751.2    0.31    0.61    N/A    6132.7    0.34    0.67    N/A
[0]      4194304       1048576   float     sum     11100    0.38    0.74    N/A     10899    0.38    0.76    N/A
[0]      8388608       2097152   float     sum    9818.9    0.85    1.68    N/A    9351.1    0.90    1.77    N/A
[0]     16777216       4194304   float     sum     17219    0.97    1.92    N/A     17121    0.98    1.93    N/A
[0]     33554432       8388608   float     sum     35836    0.94    1.84    N/A     36609    0.92    1.80    N/A
[0]     67108864      16777216   float     sum     73365    0.91    1.80    N/A     78911    0.85    1.67    N/A

That works out to about 80 ms for 16M float32s (the 67108864-byte row).

For my particular training setup, I am seeing ~600 ms spent doing the allreduce out of ~800 ms total per training loop iteration, so speeding this up could substantially reduce my overall runtime.

The two ways that seem most promising to me would be:

  1. Use XLA's existing NCCL support or extend it to do this call through XLA

  2. Use pointers to GPU memory to call NCCL from Python (not sure if this would encounter weird issues with XLA also using CUDA)

What do you guys think?

christopherhesse · Jul 10 '19 01:07

Thanks for raising this! I used mpi4py a bit in grad school and loved it, so it will be really satisfying to get this working well.

I don't have much to offer right now but just wanted to collect some of our clues in one place. @hawkinsp just added a way to grab raw XLA:GPU memory pointers to our Python XLA client so that we could explore option 2: here's the TF commit.

We might need an XLA expert to weigh in on option 1, but AIUI XLA:GPU's NCCL support is for the single-host setting. That said, XLA:TPU's multi-replica computations can span accelerators across multiple hosts, so there could be some path forward there.

mattjj · Jul 10 '19 02:07

(By the way, we need to update the TF commit that our repo points to before our build process will build a version of XLA with that update. Also we'll need to update the jaxlib wheels.)

mattjj · Jul 10 '19 02:07

Thanks for the quick response! I'll give the pointers a try once XLA is updated.

christopherhesse · Jul 10 '19 17:07

@mattjj any idea when XLA will be updated next?

christopherhesse · Jul 18 '19 23:07

Related to this issue, is there a way to tell JAX to use a specific GPU without setting CUDA_VISIBLE_DEVICES?

christopherhesse · Jul 19 '19 00:07

@mattjj any idea when XLA will be updated next?

Sorry, we're getting behind because almost all of the team is in London this week.

I just kicked off a build for Linux machines, which should be done in an hour or so; @hawkinsp could you do the macOS build, or will that have to wait for us to return to the colonies?

Related to this issue, is there a way to tell JAX to use a specific GPU without setting CUDA_VISIBLE_DEVICES?

No, not yet. Would you want to place different computations on different devices, or just set a program-global parameter to control this? (The former is more general than the latter, and we plan to add that soon, but maybe the latter would be a quicker fix that covers your use case.)

mattjj · Jul 19 '19 06:07

Thanks for kicking off that build!

A program-global parameter would work for me. The reason is that it looks like I have to set CUDA_VISIBLE_DEVICES to a different value for NCCL; it's possible I can work around this by setting the env var only during NCCL initialization. Should I file a new issue if it turns out that the global parameter is important to me?
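(Concretely, the env-var workaround above would look something like this untested sketch; init_nccl is just a placeholder for whatever NCCL setup I end up using.)

import os
from contextlib import contextmanager

@contextmanager
def cuda_visible_devices(value):
    # Temporarily point CUDA_VISIBLE_DEVICES at a different device set,
    # restoring the previous value afterwards.
    prev = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    try:
        yield
    finally:
        if prev is None:
            del os.environ["CUDA_VISIBLE_DEVICES"]
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = prev

# with cuda_visible_devices("0,1,2,3,4,5,6,7"):
#     nccl_comm = init_nccl(...)  # placeholder for the actual NCCL initialization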

christopherhesse · Jul 19 '19 17:07

No need for a new issue; several others are keen on having a way to control jit device assignment, so it’s on my mind already.

(Haven’t uploaded the wheels yet.)

mattjj · Jul 19 '19 22:07

It looks like I can get NCCL to work even when using CUDA_VISIBLE_DEVICES, but performance is noticeably impacted, possibly because it can't see all the GPUs on the machine at once.

As a result, it's not required for this issue, but it would still be very nice to control device assignment, even if only in a global manner.

christopherhesse · Jul 20 '19 01:07

The unsafe_buffer_pointer() method should be available in jaxlib 0.1.22. Please experiment with it (although I wouldn't consider it a final API).
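As one way to experiment with it: a minimal sketch along these lines, assuming cupy's NCCL bindings and mpi4py are available, one GPU per MPI process, and that the raw pointer is reachable via x.device_buffer.unsafe_buffer_pointer().

import jax
import jax.numpy as jnp
from mpi4py import MPI
from cupy.cuda import nccl

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# Rank 0 creates the NCCL unique id; everyone else receives it over MPI.
uid = nccl.get_unique_id() if rank == 0 else None
uid = comm.bcast(uid, root=0)
nccl_comm = nccl.NcclCommunicator(size, uid, rank)

x = jax.device_put(jnp.ones(16_000_000, dtype=jnp.float32))
x.block_until_ready()  # make sure XLA has finished producing the buffer

# In-place sum-allreduce directly on the XLA:GPU buffer (no host copy).
ptr = x.device_buffer.unsafe_buffer_pointer()
nccl_comm.allReduce(ptr, ptr, x.size, nccl.NCCL_FLOAT32, nccl.NCCL_SUM,
                    0)  # 0 = default stream; real code would synchronize before reading x again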

hawkinsp · Jul 22 '19 12:07

The buffer pointer works great! Should I leave this issue open until there's some sort of control over device assignment? It looks like my setup with NCCL should be even faster once that is supported by JAX.

Specifically, I want to confine JAX to using a single GPU (so no allocating a ton of memory on the other GPUs) without setting CUDA_VISIBLE_DEVICES.

christopherhesse · Jul 22 '19 23:07

@mattjj I wanted to mention that I may have found a workaround for this CUDA_VISIBLE_DEVICES issue that does not require any JAX changes. I will try it out and update this issue with the result.

christopher-hesse · Jul 24 '19 06:07

Specifically, the fix for that is to upgrade to CUDA 10.1: https://github.com/NVIDIA/nccl/blob/master/src/transport/p2p.cc#L75

It looks like JAX does not have a jaxlib build for CUDA 10.1 though; do I need to build that myself?

christopher-hesse · Jul 24 '19 06:07

Looks like not only does TensorFlow not support CUDA 10.1, but neither do our compute clusters, so never mind.

christopherhesse · Jul 24 '19 23:07

I think we can consider this one fixed! JAX supports multi-worker GPU computation, and you can perform an all-reduce across hosts using pmap and psum.
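A minimal sketch of that approach (one process per host; the coordinator address and process counts below are placeholders):

import jax
import jax.numpy as jnp

# Run one copy of this per host, e.g. via mpiexec or your cluster launcher.
jax.distributed.initialize(coordinator_address="10.0.0.1:1234",  # placeholder
                           num_processes=8,                      # placeholder
                           process_id=0)                         # set per host

# One shard per local device; pmap + psum all-reduces across every device
# on every participating host.
local = jnp.ones((jax.local_device_count(), 16_000_000), dtype=jnp.float32)
summed = jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(local)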

hawkinsp · Aug 12 '22 20:08