ChatGLM3 icon indicating copy to clipboard operation
ChatGLM3 copied to clipboard

RMSNorm的不同实现方式

Open trundleyrg opened this issue 1 year ago • 0 comments

System Info / 系統信息

torch版本:2.12

Who can help? / 谁可以帮助到您?

No response

Information / 问题信息

  • [X] The official example scripts / 官方的示例脚本
  • [ ] My own modified scripts / 我自己修改的脚本和任务

Reproduction / 复现过程

class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)

chatglm的RMSNorm实现中,weight用的是torch.empty的随机初始化。而llama的RMSNorm实现中,用的是torch.ones全一初始化。请问,chatglm用torch.empty是有随机初始化缩放系数的考虑嘛?

Expected behavior / 期待表现

能介绍一下两种实现方式的优劣吗?

trundleyrg avatar May 28 '24 12:05 trundleyrg