
How to save a checkpoint when using data parallel and MoE experts

Open · Satan012 opened this issue 4 years ago · 7 comments

Satan012 · Feb 08 '22 14:02

You can use the general way to save & reload models in any kind of distributed setting, with each peer process holding one slice of the checkpoint files. https://github.com/microsoft/tutel/pull/88/files
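
For illustration, a minimal sketch of per-rank checkpointing under that scheme (the helper names and file-naming convention here are hypothetical, not the API from that PR):

    import os
    import torch

    def save_sharded_checkpoint(model, optimizer, save_dir, rank):
        # Each rank writes its own file: expert weights differ across ranks,
        # so no single rank holds the complete model state.
        torch.save({'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()},
                   os.path.join(save_dir, 'checkpoint.rank%d.pt' % rank))

    def load_sharded_checkpoint(model, optimizer, save_dir, rank):
        # Reload the slice this rank wrote; the world size at reload time
        # must match the one used when saving.
        state = torch.load(os.path.join(save_dir, 'checkpoint.rank%d.pt' % rank),
                           map_location='cpu')
        model.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])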

ghostplant · Feb 08 '22 16:02

I only use the moe_layer function in my code, and the parameters of all experts are the same. dist_rank=2; my CUDA version is 10.1 and my PyTorch version is 1.8.

    self.ffn_text = tutel_moe.moe_layer(
        gate_type={'type': 'top', 'k': 2, 'fp32_gate': True},
        experts={'type': 'ffn', 'count_per_node': 1,
                 'hidden_size_per_expert': hidden_features,
                 'activation_fn': lambda x: F.relu(x)},
        model_dim=in_features,
        scan_expert_func=lambda name, param: setattr(param, 'skip_allreduce', True),
        seeds=(1, parallel_env['dist_rank'] + 1, 1),
        a2a_ffn_overlap_degree=1,
    ).to(parallel_env['local_device'])

Satan012 · Feb 09 '22 05:02

Do you skip allreduce on the expert parameters? If not, their values will become the same.
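
To illustrate the point (a sketch using torch.distributed directly, not the tutel example code): with plain data parallelism every gradient is averaged across ranks, so rank-local expert weights converge to identical values unless the manual all-reduce skips them:

    import torch.distributed as dist

    def allreduce_shared_gradients(model, world_size):
        # Average gradients of shared (non-expert) parameters across ranks;
        # parameters tagged by scan_expert_func keep their rank-local gradients.
        for param in model.parameters():
            if getattr(param, 'skip_allreduce', False) or param.grad is None:
                continue
            dist.all_reduce(param.grad)
            param.grad /= world_size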

ghostplant · Feb 09 '22 08:02

I have changed my DistributedDataParallel to torch.nn.parallel.DistributedDataParallel, and that problem has been solved, but now there is a new error....

    Traceback (most recent call last):
      File "Train_tutel.py", line 354, in <module>
        main()
      File "Train_tutel.py", line 251, in main
        loss = user_module(train_input, oodn_output=oodn_output)
      File "/home/pai/lib/python3.6/site-packages/torch/nn/modules/module.py", line 903, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/home/pai/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 714, in forward
        if self.reducer._rebuild_buckets():
    RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one.
    This error indicates that your module has parameters that were not used in producing loss. You can
    enable unused parameter detection by (1) passing the keyword argument find_unused_parameters=True
    to torch.nn.parallel.DistributedDataParallel; (2) making sure all forward function outputs
    participate in calculating loss. If you already have done the above two steps, then the distributed
    data parallel module wasn't able to locate the output tensors in the return value of your module's
    forward function. Please include the loss function and the structure of the return value of forward
    of your module when reporting this issue (e.g. list, dict, iterable).

Satan012 · Feb 10 '22 02:02

Is it related to moe_layer? Does this suggestion help? https://github.com/pytorch/pytorch/issues/22436
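
For reference, the suggestion in that issue amounts to enabling unused-parameter detection when wrapping the model (model and local_rank are placeholders here):

    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        find_unused_parameters=True,  # let DDP tolerate parameters unused in a given step
    )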

ghostplant · Feb 10 '22 06:02

When I set find_unused_parameters=True, none of the expert parameters are updated.

Satan012 · Feb 10 '22 07:02

It should be related to your usage of torch.nn.parallel.DistributedDataParallel. helloworld_ddp.py is verified to work well with torch's DistributedDataParallel even when the model contains parameters that don't contribute to the loss; you can check this by trying it with multiple gates, e.g.:

diff --git a/tutel/examples/helloworld_ddp.py b/tutel/examples/helloworld_ddp.py
index c6464fb..fe0e382 100755
--- a/tutel/examples/helloworld_ddp.py
+++ b/tutel/examples/helloworld_ddp.py
@@ -66,7 +66,7 @@ class ExampleModel(torch.nn.Module):
         self._ddp_params_and_buffers_to_ignore = list()

         self._moe_layer = tutel_moe.moe_layer(
-            gate_type = {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate},
+            gate_type = [{'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate}, {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate}],
             experts = {'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_size, 'activation_fn': lambda x: F.relu(x)},
             model_dim = model_dim,
             scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
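
The _ddp_params_and_buffers_to_ignore list visible in the diff context is the key mechanism: the example registers the names of parameters marked with skip_allreduce in that list before wrapping the module, and DDP then leaves them out of its gradient synchronization. Roughly (a sketch of the mechanism; the example script wraps this in a helper method, and local_rank is a placeholder):

    # Register expert parameter names before wrapping, so DDP excludes them
    # from gradient synchronization entirely.
    for name, param in model.named_parameters():
        if hasattr(param, 'skip_allreduce'):
            model._ddp_params_and_buffers_to_ignore.append(name)

    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])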

ghostplant · Feb 10 '22 07:02