Gemma issues identified by the Unsloth team / impact on mlx code? (shared on our discord as well)
Posted this in the Discord (https://discord.gg/pEPVK6gGfW) as well. Thanks to the awesome work of the Unsloth team, who've identified some bugs in Gemma implementations across the ecosystem: https://unsloth.ai/blog/gemma-bugs
I think these are the potential fixes for the mlx-lm examples repo, but would love a second pair of eyes:
RMSNorm: upcast to fp32 at the start of the forward pass and cast the result back down to the weight's dtype at the end (or alternatively, set the weight's dtype in the RMSNorm class to fp32).
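A minimal sketch of what that could look like, assuming mlx's usual module shape (the `(1 + weight)` scaling is Gemma's convention; the exact class layout here is illustrative, not a quote of gemma.py):

```python
import mlx.core as mx
import mlx.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        orig_dtype = x.dtype
        # do the whole norm in fp32 regardless of the activation dtype
        x = x.astype(mx.float32)
        n = x * mx.rsqrt(x.square().mean(axis=-1, keepdims=True) + self.eps)
        # Gemma scales by (1 + weight); cast back down only at the very end
        return ((1.0 + self.weight) * n).astype(orig_dtype)
```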
RoPE issues: the Unsloth folks say the position indices need to be int32 and not bfloat16, but I don't see gemma.py in mlx-examples handling that explicitly - maybe not an issue? Can't tell. Same with "RoPE is sensitive to a*(1/x) vs a/x" (i.e. dividing directly vs multiplying by a precomputed reciprocal). A sketch of both points is below.
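For reference, a hedged sketch of what those two points mean in isolation (`rope_angles` is a hypothetical helper for illustration, not mlx API, and not how gemma.py structures its RoPE):

```python
import mlx.core as mx

def rope_angles(seq_len: int, dims: int, base: float = 10000.0):
    # build position indices as int32 (exact even for large positions),
    # not bfloat16, then upcast to fp32 for the trig math
    positions = mx.arange(seq_len, dtype=mx.int32).astype(mx.float32)
    exponents = mx.arange(0, dims, 2, dtype=mx.float32) / dims
    # divide directly (a / x) instead of multiplying by a precomputed
    # reciprocal (a * (1 / x)); the last-bit difference matters for RoPE
    angles = mx.expand_dims(positions, 1) / mx.power(base, exponents)
    return mx.cos(angles), mx.sin(angles)
```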
GELU: needs to be the tanh approximation, not the exact form.
Looking at mlx.nn, the exact form appears to be the default (nn.gelu is exact, and the nn.GELU class defaults to approx='none' when you don't pass a param). gemma.py does this:
```python
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
```
which I assume means it's using the exact form, and should switch to the tanh approximation (approx='precise' on the nn.GELU class, or the equivalent functional form).
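If that's right, the one-line fix would be something like the following, assuming nn.gelu_approx is mlx's functional counterpart of nn.GELU(approx='precise'):

```python
# tanh-approximate GELU instead of the exact (erf-based) nn.gelu
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
```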