BitNet-Transformers
Question about BitLinear Implementation
Hi, I have a question about your `BitLinear.forward()` implementation. The BitNet paper defines the output as

y = binarized_weight(W) @ AbsMaxQuant(LN(x)) * (beta * gamma / Q_b)

where LN is layer normalization, as described in the paper. In your implementation, however, the output appears to be computed as

y = AbsMaxQuant(binarized_weight(W) @ x)

Why do you drop LN(x) and swap the order relative to the paper's formulation? There also doesn't seem to be a dequantization step that rescales by beta * gamma / Q_b. Could you share the reasoning behind this implementation? If I have misunderstood it, please correct me.
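For reference, here is a minimal forward-only sketch of how I read the paper's formulation (my own code, not from this repo; the class name `BitLinearPaper`, the `bits` argument, and the clamping details are my assumptions, and a real implementation would also need a straight-through estimator so gradients can flow through the quantization during training):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BitLinearPaper(nn.Module):
    """Sketch of y = W_hat @ AbsMaxQuant(LN(x)) * (beta * gamma / Q_b), as I read the paper."""

    def __init__(self, in_features, out_features, bits=8):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.ln = nn.LayerNorm(in_features)
        self.Q_b = 2 ** (bits - 1)
        self.eps = 1e-5

    def forward(self, x):
        # 1. LayerNorm on the activations before quantization, as in the paper.
        x = self.ln(x)

        # 2. AbsMax quantization of activations: scale by Q_b / gamma and clip.
        gamma = x.abs().max().clamp(min=self.eps)
        x_q = (x * self.Q_b / gamma).clamp(-self.Q_b + self.eps, self.Q_b - self.eps)

        # 3. Binarize weights via the sign of the zero-centered weights; beta = mean(|W|).
        w = self.weight
        beta = w.abs().mean()
        w_bin = torch.sign(w - w.mean())

        # 4. Matmul with binarized weights and quantized activations, then
        #    dequantize by rescaling with beta * gamma / Q_b.
        return F.linear(x_q, w_bin) * (beta * gamma / self.Q_b)
```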