pytorch-transformer

Why does the Encoder have a norm layer on its final output?

Open · SeekPoint opened this issue 1 year ago · 1 comment

class Encoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        # Extra normalization applied to the output of the last encoder block
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

SeekPoint · Dec 01 '24 17:12
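For context, this Encoder looks like a pre-norm design: if ResidualConnection normalizes its input before calling the sublayer, then every block's output leaves un-normalized, and the final self.norm in Encoder.forward is there to normalize the last block's output. A minimal sketch of that layout (the LayerNormalization and ResidualConnection bodies below are assumptions for illustration, not copied from the repo):

import torch
import torch.nn as nn

class LayerNormalization(nn.Module):
    # Simplified LayerNorm with a learnable scale and bias (an assumed shape of the repo's class)
    def __init__(self, features: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

class ResidualConnection(nn.Module):
    # Pre-norm residual: normalize first, apply the sublayer, then add the skip connection.
    # If the repo's ResidualConnection looks like this, the block output itself is never
    # normalized, which is why the Encoder applies one last norm on its final output.
    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.norm = LayerNormalization(features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

Under this pre-norm assumption, only the very last output needs an explicit norm, which is exactly what the Encoder in the question does.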

@hkproj I think the implementation deviates from the architecture proposed in the paper. The paper applies normalization after each sublayer: the output of the multi-head attention is added to the residual connection and normalized, that result is passed to the feed-forward network, and the feed-forward output is again added to the residual connection and normalized. Here's pseudocode of what it should look like:

class EncoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])
        # One LayerNormalization per sublayer, matching the paper's post-norm layout
        self.norms = nn.ModuleList([LayerNormalization(features) for _ in range(2)])

    def forward(self, x, src_mask):
        # Self-attention sublayer: residual add, then normalize
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.norms[0](x)
        # Feed-forward sublayer: residual add, then normalize
        x = self.residual_connections[1](x, self.feed_forward_block)
        return self.norms[1](x)

class Encoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers

    def forward(self, x, mask):
        # No final norm needed: each block already normalizes its own output
        for layer in self.layers:
            x = layer(x, mask)
        return x

PT-10 · Jan 26 '25 06:01
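For comparison, the paper's post-norm layout computes LayerNorm(x + Sublayer(x)) inside each residual connection, so the last block's output is already normalized and the Encoder needs no extra norm. A minimal sketch of such a residual connection (the class name PostNormResidualConnection is hypothetical and uses nn.LayerNorm for self-containment, shown only to illustrate the difference):

import torch.nn as nn

class PostNormResidualConnection(nn.Module):
    # Hypothetical name, for illustration only. Post-norm residual as in
    # "Attention Is All You Need": add the sublayer output to its input,
    # then normalize the sum.
    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return self.norm(x + self.dropout(sublayer(x)))

With post-norm, every block's output is already normalized, which is why the Encoder in the pseudocode above drops the final norm; with pre-norm, the extra norm at the end of the Encoder compensates for the last block's un-normalized output.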