
Drop-in replacement needed for F.linear?

tanvoontao opened this issue 7 months ago · 0 comments


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class BitLinearInference(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
    ):
        super().__init__()

        self.in_f = in_features
        self.out_f = out_features

        # Ternary weight and its scale are precomputed offline, so they are
        # registered as buffers rather than trainable parameters.
        self.register_buffer("w", torch.empty((out_features, in_features)))
        self.register_buffer("w_scale", torch.empty((1,), dtype=torch.float32))

        self.norm = nn.RMSNorm(
            normalized_shape=in_features,
            eps=1e-5,
            elementwise_affine=True
        )

    def forward(self, x: Tensor) -> Tensor:
        x_norm = self.norm(x)
        # quantize_activation is defined elsewhere; it returns the quantized
        # activation and its dequantization scale.
        x_int, x_scale = quantize_activation(x_norm)
        y_int = F.linear(x_int, self.w)
        # Rescale the integer-domain result back to floating point.
        y = y_int * (self.w_scale * x_scale)
        return y
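
quantize_activation is not shown above; a minimal sketch of what such a helper typically looks like (per-token absmax quantization to the int8 range, returning the dequantization scale) might be the following. This is an assumption for illustration, not the author's actual helper.

# Assumed definition, not part of the issue: per-token absmax quantizer.
def quantize_activation(x: Tensor) -> tuple[Tensor, Tensor]:
    # Per-token absolute maximum over the feature dimension.
    scale = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) / 127.0
    # Quantize to the int8 range; kept in float dtype so F.linear accepts it.
    x_int = (x / scale).round().clamp(-128, 127)
    # scale is the dequantization factor: x ≈ x_int * scale
    return x_int, scale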

Is there a way to drop-in replace F.linear with the custom kernel? The current bitnet.cpp only supports running full LLMs, not this simple drop-in replacement. Note that the activation is 3D and the weight is 2D, but it seems the kernel only supports 2D matmul.
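
For the 3D-activation / 2D-weight mismatch, one common workaround is to flatten the leading (batch, sequence) dimensions before calling a 2D-only kernel and restore them afterwards. A sketch of that pattern is below; kernel_gemm_2d is a hypothetical placeholder for whatever 2D entry point the custom kernel exposes, not an existing bitnet.cpp API.

# Sketch only: wrap a 2D-only matmul kernel so it accepts 3D activations.
def linear_via_2d_kernel(x_int: Tensor, w: Tensor, kernel_gemm_2d) -> Tensor:
    *lead, in_f = x_int.shape
    x2d = x_int.reshape(-1, in_f)          # (batch * seq, in_features)
    y2d = kernel_gemm_2d(x2d, w)           # (batch * seq, out_features)
    return y2d.reshape(*lead, w.shape[0])  # back to (batch, seq, out_features)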

tanvoontao · Jul 03 '25