
[Feature Request] Task-level Optimization with Distributed Data Parallelization

Open XuehaiPan opened this issue 3 years ago • 0 comments

Motivation

Task-level parallelization for multi-host multi-process optimization.

Batch-level parallelization can be implemented easily by wrapping the network (`nn.Module`) with `torch.nn.DataParallel` or `torch.nn.parallel.DistributedDataParallel`.
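A minimal sketch of the batch-level approach (the `Linear` model here is just a stand-in for the user's network):

```python
import torch

# Stand-in for an arbitrary user-defined nn.Module.
model = torch.nn.Linear(8, 4)

# Batch-level data parallelism wraps the module: each replica processes a
# slice of the input batch, while all replicas share one set of parameters.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)  # single-host, multi-GPU
    # For multi-host / multi-process training (requires an initialized
    # process group via torch.distributed.init_process_group):
    # model = torch.nn.parallel.DistributedDataParallel(model)

out = model(torch.randn(16, 8))
print(out.shape)  # torch.Size([16, 4])
```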

However, for algorithms that require task-level parallelization, none of the above solutions work. `torch.nn.DataParallel` and `torch.nn.parallel.DistributedDataParallel` provide module-level parallelization: the wrapper replicates the user's module into multiple copies and runs the forward pass in parallel, with every replica sharing the same parameters. For task-level parallelization, each task needs to maintain its own model parameters and (optionally) its own training data, so the module parameters may differ across tasks.

Solution

Combine `functorch.vmap` (to vectorize the forward/optimization step over a leading "task" dimension of stacked per-task parameters) with distributed data parallel optimization across processes/hosts.
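A sketch of the task-level idea using the `torch.func` successor to the `functorch` API (`stack_module_state`, `functional_call`, `vmap`); the 5-task setup and model are hypothetical:

```python
import torch
from torch.func import functional_call, stack_module_state, vmap

# Hypothetical setup: each task owns an independent copy of the same
# architecture, so parameters differ across tasks (unlike DDP replicas).
num_tasks = 5
models = [torch.nn.Linear(8, 4) for _ in range(num_tasks)]

# Stack per-task parameters/buffers along a new leading "task" dimension.
params, buffers = stack_module_state(models)
base = models[0]  # skeleton module; its own weights are not used below

def forward_one_task(task_params, task_buffers, x):
    # Run the base architecture with one task's parameters substituted in.
    return functional_call(base, (task_params, task_buffers), (x,))

# vmap maps over the task dimension of both the parameters and the data,
# evaluating all tasks in a single vectorized call.
x = torch.randn(num_tasks, 16, 8)  # one batch per task
out = vmap(forward_one_task)(params, buffers, x)
print(out.shape)  # torch.Size([5, 16, 4])
```

Distributing this across hosts would then shard the task dimension over processes, with each process vmapping over its local shard.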

Alternatives

N/A

Additional context

Resources:

PyTorch:

JAX:

Checklist

  • [X] I have checked that there is no similar issue in the repo (required)

XuehaiPan · Aug 10 '22 08:08