BitNet: drop-in replacement needed for F.linear?
class BitLinearInference(nn.Module):
    """Inference-only BitLinear layer: RMSNorm -> activation quantization ->
    integer matmul -> dequantization rescale.

    The quantized weight ``w`` and its scalar scale ``w_scale`` are stored as
    buffers (not ``nn.Parameter``), so they follow ``.to(device)`` and
    ``state_dict`` but are excluded from optimizer parameter groups.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 ):
        super().__init__()
        self.in_f = in_features
        self.out_f = out_features
        # Pre-quantized weight matrix and its dequantization scale.
        # NOTE(review): ``w`` is created as float ``torch.empty`` here and is
        # presumably overwritten by a checkpoint load — confirm the loader
        # fills both buffers before forward() is called.
        self.register_buffer("w", torch.empty((out_features, in_features)))
        self.register_buffer("w_scale", torch.empty((1,), dtype=torch.float32))
        self.norm = nn.RMSNorm(
            normalized_shape=in_features,
            eps=1e-5,
            elementwise_affine=True,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Normalize, quantize the activation, matmul, and rescale.

        ``F.linear`` broadcasts over all leading dimensions, so a 3-D
        activation ``(batch, seq, in_features)`` against the 2-D weight
        ``(out_features, in_features)`` works without reshaping.
        """
        x_norm = self.norm(x)
        x_int, x_scale = quantize_activation(x_norm)
        y_int = F.linear(x_int, self.w)
        # Dequantize: fold the weight scale and per-activation scale back in.
        y = y_int * (self.w_scale * x_scale)
        return y  # BUG FIX: the original snippet computed y but never returned it
Is there a way to drop-in replace `F.linear` with the custom kernel? The current bitnet.cpp only supports full LLM inference, not this simple drop-in replacement. Note that the activation is 3-D while the weight is 2-D, but it seems the kernel only supports 2-D matmul.