[Feature Request] Task-level Optimization with Distributed Data Parallelization
Motivation
Support task-level parallelization for multi-host, multi-process optimization.
Batch-level parallelization can be implemented easily by wrapping the network (`nn.Module`) with:
- `torch.nn.DataParallel` (single-host, multi-GPU; SPMD)
- `torch.nn.parallel.DistributedDataParallel` (multi-host, multi-GPU)
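For reference, the batch-level approach is a one-line wrapper. Below is a minimal sketch with `torch.nn.DataParallel`; the toy `nn.Linear(3, 1)` model and the tensor shapes are arbitrary choices for illustration. `DistributedDataParallel` additionally needs a process group (`torch.distributed.init_process_group`) and one process per device, which this sketch omits.

```python
import torch
from torch import nn

# Use a GPU if one is available; DataParallel expects the wrapped module to
# live on the first device in its device_ids.
device = "cuda" if torch.cuda.is_available() else "cpu"

net = nn.Linear(3, 1).to(device)
# Batch-level parallelism: one set of parameters is replicated across the
# visible GPUs and each replica processes a slice of the batch. Without GPUs,
# DataParallel simply calls the wrapped module.
model = nn.DataParallel(net)

x = torch.randn(8, 3, device=device)  # dim 0 (the batch) is what gets split
y = model(x)
print(y.shape)  # torch.Size([8, 1])
```

Every replica shares the same parameters, which is exactly what task-level algorithms cannot assume.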
However, none of the above solutions work for algorithms that require task-level parallelization. `torch.nn.DataParallel` and `torch.nn.parallel.DistributedDataParallel` provide module-level parallelization: the wrapper replicates the user module into multiple copies that all share the same parameters, and each copy runs the forward pass on its share of the input batch in parallel. For task-level parallelization, each task needs to maintain its own model parameters and (optionally) its own training data, so the module parameters may differ across tasks.
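To make the contrast concrete, here is a minimal sketch of task-level forward passes, assuming PyTorch 2.x, where functorch's transforms are available as `torch.func` (`stack_module_state`, `functional_call`, `vmap`). The four tasks, their independently initialized `nn.Linear(3, 1)` models, and their random data are hypothetical; `vmap` maps over the leading task dimension of both the stacked parameters and the data.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

num_tasks = 4
# Hypothetical setup: each task owns an independently initialized model,
# so parameters differ across tasks (unlike DataParallel replicas).
task_models = [nn.Linear(3, 1) for _ in range(num_tasks)]
# Stack per-task parameters/buffers: every leaf gets a leading task dim.
params, buffers = stack_module_state(task_models)

# A stateless template on the "meta" device; real values come from params.
template = copy.deepcopy(task_models[0]).to("meta")


def forward_one_task(p, b, x):
    # Run the template model with one task's parameters and buffers.
    return functional_call(template, (p, b), (x,))


# Per-task data with shape (num_tasks, batch, features).
x = torch.randn(num_tasks, 5, 3)
out = vmap(forward_one_task)(params, buffers, x)
print(out.shape)  # torch.Size([4, 5, 1])
```

Each slice `out[i]` equals what `task_models[i]` alone would produce on `x[i]`, but all tasks are evaluated in one vectorized call.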
Solution
Use `functorch.vmap` together with distributed data-parallel optimization.
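One way the two pieces could fit together, as a sketch only: `vmap` vectorizes over the tasks held by one process, while the task list is sharded across ranks. It assumes PyTorch 2.x (`torch.func.grad`, `torch.func.vmap`). The per-task linear-regression loss, the learning rate, the round-robin sharding in `local_task_slice`, and the `average_across_ranks` helper are hypothetical illustrations, not an existing API. Without an initialized process group it runs as a single rank.

```python
import torch
import torch.distributed as dist
from torch.func import grad, vmap


def loss_fn(w, x, y):
    # Hypothetical per-task loss: linear regression, w (F,), x (B, F), y (B,).
    return ((x @ w - y) ** 2).mean()


def inner_sgd_step(w, x, y, lr=0.1):
    # One task-specific SGD step; each task updates its own parameters.
    return w - lr * grad(loss_fn)(w, x, y)


def local_task_slice(num_tasks):
    # Hypothetical round-robin sharding: rank r owns tasks r, r + world, ...
    if dist.is_available() and dist.is_initialized():
        rank, world = dist.get_rank(), dist.get_world_size()
    else:
        rank, world = 0, 1
    return slice(rank, num_tasks, world)


def average_across_ranks(t):
    # Average a tensor (e.g. a meta-gradient or a metric) over all ranks.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        t /= dist.get_world_size()
    return t


torch.manual_seed(0)
num_tasks, batch, features = 8, 16, 3
ws = torch.zeros(num_tasks, features)  # per-task parameters
xs = torch.randn(num_tasks, batch, features)  # per-task inputs
ys = torch.randn(num_tasks, batch)  # per-task targets

idx = local_task_slice(num_tasks)
# Vectorize the inner-loop update over this rank's tasks.
new_ws = vmap(inner_sgd_step)(ws[idx], xs[idx], ys[idx])
# Post-update loss per task, averaged locally and then across ranks.
post_loss = vmap(loss_fn)(new_ws, xs[idx], ys[idx]).mean()
post_loss = average_across_ranks(post_loss)
```

Launched under `torchrun` with a process group initialized via `dist.init_process_group`, each rank would only handle its own slice of tasks, and `average_across_ranks` would combine their results.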
Alternatives
N/A
Additional context
Resources:
PyTorch:
- Tutorial: PyTorch Distributed Overview
- Tutorial: Distributed Data Parallel
- API: Module-level Data Parallel: `torch.nn.DataParallel` (SPMD)
- API: Module-level Distributed Data Parallel: `torch.nn.parallel.DistributedDataParallel`
- API: PyTorch Distributed Optimizers: `torch.distributed.optim`
- API: Vectorization map: `functorch.vmap`
JAX:
- Tutorial: Named axes and easy-to-revise parallelism
- API: Vectorization map: `jax.vmap`
- API: Parallel map: `jax.pmap` (SPMD)
- API (Experimental): `jax.experimental.maps.xmap`
- Tutorial: Using JAX in multi-host and multi-process environments
Checklist
- [X] I have checked that there is no similar issue in the repo (required)