DeepSpeed icon indicating copy to clipboard operation
DeepSpeed copied to clipboard

Question about DeepSpeed-Domino

Open GoldenStain opened this issue 3 months ago • 0 comments

I have thoroughly reviewed your code and noticed that DominoTransformerLayer.forward() makes extensive use of dist.all_reduce. However, to my knowledge, dist.all_reduce does not have a corresponding backward (autograd) implementation. Could you explain why the results are still correct?

class DominoTransformerLayer(DominoModule):
    """A domino single transformer layer.
    [s, b, h] -> [s, b, h]

    The layer processes two micro-batches per call. ``forward`` receives a
    pair of hidden states, runs self-attention followed by an MLP on each of
    them, and returns a pair of hidden states. The tensor-parallel
    all-reduces on the attention and MLP outputs are launched with
    ``async_op=True`` and only waited on later, after work for the other
    micro-batch has been issued in between.

    Args:
        config: Configuration object. This class reads
            ``apply_residual_connection_post_layernorm``, ``hidden_size``,
            ``layernorm_epsilon``, ``hidden_dropout``, ``ffn_hidden_size``,
            ``gated_linear_unit``, ``init_method``,
            ``output_layer_init_method`` and ``add_bias_linear`` from it, and
            also forwards it to the attention and linear sub-modules.
        mpu: Model-parallel utility object; must provide
            ``get_tensor_model_parallel_world_size()`` and
            ``get_tensor_model_parallel_group()``.
        apply_rotary_pos_emb: Forwarded unchanged to ``ShardedAttention``.
        layer_number: Stored on the instance and forwarded to
            ``ShardedAttention``.
        layer_type: Stored on the instance; not otherwise used in this class.
        self_attn_mask_type: Forwarded to ``ShardedAttention`` as
            ``attn_mask_type``.
        drop_path_rate: Accepted but not used anywhere in this class.
    """

    def __init__(self,
                 config,
                 mpu,
                 apply_rotary_pos_emb,
                 layer_number,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.causal,
                 drop_path_rate=0.):

        super(DominoTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        # Selects which tensor feeds the residual additions in forward():
        # the layernorm output (True) or the layernorm input (False).
        self.apply_residual_connection_post_layernorm \
            = config.apply_residual_connection_post_layernorm

        # NOTE(review): set here but never read inside this class.
        self.llama_model = False

        # Layernorm applied to each incoming micro-batch before attention.
        self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = ShardedAttention(config,
                                               mpu,
                                               apply_rotary_pos_emb,
                                               layer_number,
                                               attention_type=AttnType.self_attn,
                                               attn_mask_type=self_attn_mask_type)

        # Dropout probability handed to bias_dropout_add() below.
        self.hidden_dropout = config.hidden_dropout

        # Layernorm applied after the attention residual, before the MLP.
        self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        # NOTE(review): this local (including the doubling for gated linear
        # units) is never used below; output_size_c / input_size_r are taken
        # directly from config.ffn_hidden_size. Confirm whether the gated
        # case is meant to be supported here.
        ffn_hidden_size = config.ffn_hidden_size
        if config.gated_linear_unit:
            ffn_hidden_size *= 2

        # "_c" sizes describe linear_fc1 (hidden -> ffn), "_r" sizes describe
        # linear_fc2 (ffn -> hidden).
        self.output_size_c = config.ffn_hidden_size
        self.input_size_c = config.hidden_size
        self.input_size_r = config.ffn_hidden_size
        self.output_size_r = self.input_size_c

        # The ffn dimension is divided across the tensor-parallel ranks.
        # Integer division: assumes ffn_hidden_size is a multiple of the
        # tensor-parallel world size -- TODO confirm it is validated upstream.
        tp_world_size = mpu.get_tensor_model_parallel_world_size()
        self.TP_group = mpu.get_tensor_model_parallel_group()
        self.output_size_per_partition = self.output_size_c // tp_world_size
        self.input_size_per_partition = self.input_size_r // tp_world_size

        # First MLP projection: full hidden size in, this rank's ffn slice out.
        self.linear_fc1 = DominoAsyncColumnParallelLinear(self.input_size_c,
                                                          self.output_size_per_partition,
                                                          mpu.get_tensor_model_parallel_group(),
                                                          config=config,
                                                          init_method=config.init_method,
                                                          bias=config.add_bias_linear)

        self.mlp_activation_func = F.gelu

        # Second MLP projection: this rank's ffn slice in, full hidden size
        # out. It is given no process group; forward() issues the all-reduce
        # on its output itself. With skip_bias_add=True the bias is returned
        # alongside the output and applied later by bias_dropout_add_func.
        self.linear_fc2 = RowParallelLinearNoComm(self.input_size_per_partition,
                                                  self.output_size_r,
                                                  config=config,
                                                  init_method=config.output_layer_init_method,
                                                  bias=config.add_bias_linear,
                                                  skip_bias_add=True)

        # Callable invoked as f(x, bias, residual) in forward().
        self.bias_dropout_add_func = bias_dropout_add(self.hidden_dropout)

    def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
        """Run the layer on a pair of micro-batches.

        Args:
            hidden_states: A pair ``(hidden_states0, hidden_states1)``, one
                entry per micro-batch.
            attention_mask: Forwarded to self-attention for both micro-batches.
            rotary_pos_emb: Optional; forwarded to self-attention.

        Returns:
            A pair ``(hidden_states0, hidden_states1)`` holding the layer
            output for each micro-batch.

        NOTE(review): the all-reduces below are issued directly via
        ``dist.all_reduce`` on activations. How gradients are communicated in
        the backward pass is not visible in this method; presumably it is
        handled by ``_Wait_bwd_comm`` and ``DominoAsyncColumnParallelLinear``
        through ``DominoUtil.HANDLE_DIC`` -- confirm against their definitions.
        """

        hidden_states0, hidden_states1 = hidden_states

        layernorm_output0 = self.input_layernorm(hidden_states0)
        # _Wait_bwd_comm is defined elsewhere; presumably a pass-through in
        # the forward pass that waits on a pending backward communication
        # handle for this micro-batch -- confirm.
        layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)

        # Micro batch 0: attention
        attention_output0, attention_bias0 = self.self_attention(layernorm_output0,
                                                                 attention_mask,
                                                                 DominoUtil.BATCH_0,
                                                                 rotary_pos_emb=rotary_pos_emb)

        # Launch the tensor-parallel all-reduce without blocking; the reduced
        # values are read from attention_output0 only after fwd_handle0.wait().
        fwd_handle0 = dist.all_reduce(attention_output0, group=self.TP_group, async_op=True)
        # End of Micro batch 0: attention

        # Micro batch 1: attention
        # Issued while micro-batch 0's all-reduce is still in flight.
        layernorm_output1 = self.input_layernorm(hidden_states1)
        layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)

        attention_output1, attention_bias1 = self.self_attention(layernorm_output1,
                                                                 attention_mask,
                                                                 DominoUtil.BATCH_1,
                                                                 rotary_pos_emb=rotary_pos_emb)
        fwd_handle1 = dist.all_reduce(attention_output1, group=self.TP_group, async_op=True)

        # Micro batch 0: Residual connection.
        # attention_output0 must be fully reduced before it is consumed.
        fwd_handle0.wait()
        if self.apply_residual_connection_post_layernorm:
            residual0 = layernorm_output0
        else:
            residual0 = hidden_states0

        layernorm_input0 = self.bias_dropout_add_func(attention_output0, attention_bias0, residual0)

        layernorm_output0 = self.post_attention_layernorm(layernorm_input0)
        layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)

        # residual0 is rebound here to the residual used after the MLP.
        if self.apply_residual_connection_post_layernorm:
            residual0 = layernorm_output0
        else:
            residual0 = layernorm_input0
        # End of Micro batch 0: Residual connection.

        # ------------ MLP ------------
        # Micro batch 0: MLP
        # The second return value of linear_fc1 is discarded.
        output0, _ = self.linear_fc1(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)
        output0 = self.mlp_activation_func(output0)

        # Micro batch 1: Residual connection.
        # Placed between micro-batch 0's fc1 and fc2 calls.
        fwd_handle1.wait()
        if self.apply_residual_connection_post_layernorm:
            residual1 = layernorm_output1
        else:
            residual1 = hidden_states1

        layernorm_input1 = self.bias_dropout_add_func(attention_output1, attention_bias1, residual1)

        layernorm_output1 = self.post_attention_layernorm(layernorm_input1)
        layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)

        # residual1 is rebound here to the residual used after the MLP.
        if self.apply_residual_connection_post_layernorm:
            residual1 = layernorm_output1
        else:
            residual1 = layernorm_input1
        # End of Micro batch 1: Residual connection.

        # hidden_states0 is rebound to this rank's fc2 output for micro-batch
        # 0; the handle name fwd_handle0 is reused for its all-reduce.
        hidden_states0, last_mlp_bias = self.linear_fc2(output0)
        fwd_handle0 = dist.all_reduce(hidden_states0, group=self.TP_group, async_op=True)
        # End of Micro batch 0: MLP

        # Micro batch 1: MLP
        output1, _ = self.linear_fc1(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)
        output1 = self.mlp_activation_func(output1)

        # NOTE(review): last_mlp_bias is overwritten here, so the bias applied
        # to BOTH micro-batches below is the one returned by this second call.
        # Presumably both calls return the same linear_fc2 bias -- confirm.
        hidden_states1, last_mlp_bias = self.linear_fc2(output1)

        fwd_handle1 = dist.all_reduce(hidden_states1, group=self.TP_group, async_op=True)
        # End of Micro batch 1: MLP

        # ------------  End of MLP ------------

        # Wait for each MLP all-reduce just before its result is consumed,
        # then apply bias, dropout and the post-attention residual.
        fwd_handle0.wait()
        hidden_states0 = self.bias_dropout_add_func(hidden_states0, last_mlp_bias, residual0)

        fwd_handle1.wait()
        hidden_states1 = self.bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1)

        return hidden_states0, hidden_states1

GoldenStain avatar Oct 28 '25 16:10 GoldenStain