ChatGLM3
ChatGLM3 copied to clipboard
RMSNorm的不同实现方式
System Info / 系統信息
torch版本:2.12
Who can help? / 谁可以帮助到您?
No response
Information / 问题信息
- [X] The official example scripts / 官方的示例脚本
- [ ] My own modified scripts / 我自己修改的脚本和任务
Reproduction / 复现过程
class RMSNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
self.eps = eps
def forward(self, hidden_states: torch.Tensor):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
return (self.weight * hidden_states).to(input_dtype)
chatglm的RMSNorm实现中,weight用的是torch.empty的随机初始化。而llama的RMSNorm实现中,用的是torch.ones全一初始化。请问,chatglm用torch.empty是有随机初始化缩放系数的考虑嘛?
Expected behavior / 期待表现
能介绍一下两种实现方式的优劣吗?