
CUDA kernel does not seem to be optimized

Open · tanvoontao opened this issue 10 months ago · 0 comments

I'm using the provided `gemm_lowbit()` kernel to run inference for my model evaluation, but inference seems a bit too slow. I'm using it for a classification task.

- BitNet b1: takes 5 hours
- Transformer (baseline): takes 1 hour

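The 5 h vs 1 h figures are end-to-end, so they mix the GEMM kernel with everything around it (activation quantization, dtype casts, the weight transpose). One way to narrow this down is to time a single layer call in isolation and compare it with a plain `torch.nn.functional.linear` on the same shapes. The helper below is a minimal sketch; `time_call` is a hypothetical name. The `sync` hook is intended for something like `torch.cuda.synchronize`, because CUDA kernel launches are asynchronous and the host clock would otherwise stop before the kernel finishes.

```python
import statistics
import time
from typing import Callable, Optional


def time_call(fn: Callable[[], object],
              sync: Optional[Callable[[], None]] = None,
              warmup: int = 3,
              iters: int = 10) -> float:
    """Return the median wall-clock time of fn() in milliseconds.

    sync, if given, is called before each timer read so that asynchronous
    work (for example queued CUDA kernels) is included in the measurement.
    """
    for _ in range(warmup):  # warm-up runs: JIT compilation, caches, allocator
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued device work before stopping the clock
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)
```

For example, `time_call(lambda: layer(x), sync=torch.cuda.synchronize)` against `time_call(lambda: torch.nn.functional.linear(x, w), sync=torch.cuda.synchronize)` on the same `x` and `w` shows whether the gap is per layer or comes from elsewhere in the model.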
Here is my `BitLinearB1Inference` implementation, together with the quantization helper and the kernel loading:

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def activation_norm_quant(x: Tensor):
    """RMSNorm & per-token quantization to 8 bits. It can be implemented as a fused kernel.
    Args:
        x: an activation tensor with shape [n, d]
    Returns:
        y: a quantized activation tensor with shape [n, d]
        scale: a per-token scale for dequantization with shape [n, 1]
    """
    x = F.rms_norm(x, normalized_shape=[x.size(-1)])

    # One scale per token: the max is taken over the last dim with keepdim=True
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    return y, scale
```

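The scale returned by `activation_norm_quant` has one entry per token (the `max` is taken over `dim=-1` with `keepdim=True`), and the int8 values only recover the activations after dividing by that scale. A pure-Python sketch of the same per-token absmax scheme for a single row, with the RMSNorm step omitted; the function names are hypothetical, and both Python's `round` and `torch.round` round halves to even, so the integers match:

```python
from typing import List, Tuple


def absmax_quant_row(row: List[float]) -> Tuple[List[int], float]:
    """Per-token absmax quantization of one row to the int8 range.

    Mirrors activation_norm_quant() without the RMSNorm step:
    scale = 127 / max|x|, y = clamp(round(x * scale), -128, 127).
    """
    amax = max(max(abs(v) for v in row), 1e-5)  # same 1e-5 floor as clamp_(min=1e-5)
    scale = 127.0 / amax
    q = [max(-128, min(127, round(v * scale))) for v in row]
    return q, scale


def dequant_row(q: List[int], scale: float) -> List[float]:
    """Undo the quantization: x is approximately y / scale."""
    return [v / scale for v in q]
```

Quantizing `[0.5, -1.0, 0.25]` gives scale 127, while `[2.0, 0.0, -4.0]` gives 31.75, so each row of the GEMM output has to be divided by its own scale.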
```python
import os

from torch.utils.cpp_extension import load

# JIT-compile the custom low-bit GEMM extension
gemm_lowbit_ext = load(
    name="gemm_lowbit_ext",
    sources=[
        os.path.join("kernel", "gemm_lowbit_kernel.cu")
    ],
    extra_cflags=["-O3", "-std=c++17"],
    extra_cuda_cflags=[
        "-O3",
        "-std=c++17",
        "--expt-relaxed-constexpr",
        "--use_fast_math",
    ],
    verbose=True,
)
```

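The layer below builds the kernel's `b` operand from `self.weight.half().t()`. In PyTorch, `Tensor.t()` returns a view with swapped strides and leaves the underlying storage in its original `[N, K]` order. If `gemm_lowbit` indexes `b` through a raw pointer as a dense row-major `[K, N]` array (an assumption, since the kernel source is not shown here), it reads the wrong elements unless the tensor is made contiguous first. A small pure-Python illustration, using a flat list as the storage; the helper names are hypothetical:

```python
from typing import List


def read_row_major(buf: List[int], rows: int, cols: int) -> List[List[int]]:
    """Interpret a flat buffer as a dense row-major [rows, cols] matrix,
    which is how a raw-pointer kernel typically indexes its input."""
    return [[buf[r * cols + c] for c in range(cols)] for r in range(rows)]


def transpose(m: List[List[int]]) -> List[List[int]]:
    """Logical transpose: the values a [K, N] view of an [N, K] matrix holds."""
    return [list(col) for col in zip(*m)]


# Weight stored as [N, K] = [2, 3] in row-major order, like nn.Linear.weight
N, K = 2, 3
w_buf = [1, 2, 3,
         4, 5, 6]
w = read_row_major(w_buf, N, K)           # [[1, 2, 3], [4, 5, 6]]
wanted = transpose(w)                      # logical [K, N]: [[1, 4], [2, 5], [3, 6]]

# A transposed view keeps the same buffer, so a row-major [K, N] read sees:
misread = read_row_major(w_buf, K, N)      # [[1, 2], [3, 4], [5, 6]]

# Making it contiguous writes the transposed order into a new buffer:
contig_buf = [v for row in wanted for v in row]   # [1, 4, 2, 5, 3, 6]
fixed = read_row_major(contig_buf, K, N)   # [[1, 4], [2, 5], [3, 6]]
```

Calling `.contiguous()` on the transposed weight costs one copy, which is cheap compared with computing a GEMM on misordered data.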
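```python
from typing import Literal

import torch
import torch.nn as nn


class BitLinearB1Inference(nn.Linear):
    """
    Inference-only layer for BitNet b1 (1-bit). Weights should be quantized using a 1-bit scheme.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        use1Bit=True,
        norm: Literal['layer', 'rms'] = 'rms'
    ):
        super().__init__(in_features, out_features, bias)
        self.use1Bit = use1Bit
        # get_norm is a project-specific helper (not shown). self.norm is not used in
        # forward(), since activation_norm_quant() already applies RMSNorm.
        self.norm = get_norm(norm, in_features)

    def forward(self, x):
        # x: [B, S, d_model] => flatten => [M, K], where M = B*S, K = d_model
        bsz, seq_len, d_model = x.shape
        # Keep the per-token scale: the kernel output must be divided by it to dequantize
        x_quant, x_scale = activation_norm_quant(x)  # x_scale: [B, S, 1]
        x_2d = x_quant.half().view(bsz * seq_len, d_model)  # [M, K]

        # The kernel wants 'b' to be [K, N].
        # But PyTorch's weight is [out_features, in_features] => [N, K].
        # .t() only swaps strides; .contiguous() copies the data into row-major [K, N].
        # Note: this conversion runs on every forward call; it could be done once ahead of time.
        w_half = self.weight.half().t().contiguous()  # shape => [K, N] = [d_model, out_features]

        # Prepare output => [M, N] = [B*S, out_features]
        out_2d = torch.empty(
            (x_2d.shape[0], w_half.shape[1]),  # => [M, N]
            dtype=torch.half,
            device=x.device
        )

        # gemm_lowbit expects (M x K) * (K x N) -> (M x N);
        # the meaning of the two trailing 1.0 arguments is defined by the kernel source
        gemm_lowbit_ext.gemm_lowbit(x_2d, w_half, out_2d, 1.0, 1.0)

        # Reshape back to [B, S, out_features] and undo the per-token activation scaling.
        # If the 1-bit weights carry a separate scale, it would be applied here as well.
        # Note: self.bias is not applied.
        out_3d = out_2d.view(bsz, seq_len, -1) / x_scale
        return out_3d.to(x.dtype)
```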

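To check what `gemm_lowbit` should return for one layer on small inputs, a plain reference can help. The sketch below is pure Python and assumes the 1-bit weight binarization described in the original BitNet paper: centre `W` by its mean, take the sign (zero maps to -1), and keep `beta = mean(|W|)` to restore magnitude. The function names are hypothetical, and the issue does not state whether the model here was trained with exactly this scheme.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def binarize_weight(w: Matrix) -> Tuple[List[List[int]], float]:
    """Assumed BitNet-style 1-bit weights: sign(W - mean(W)) plus beta = mean(|W|)."""
    flat = [v for row in w for v in row]
    alpha = sum(flat) / len(flat)                  # mean used to centre the weights
    beta = sum(abs(v) for v in flat) / len(flat)   # scale that restores magnitude
    wb = [[1 if v - alpha > 0 else -1 for v in row] for row in w]
    return wb, beta


def quant_linear_ref(xq: List[List[int]], x_scale: List[float],
                     wb: List[List[int]], beta: float) -> Matrix:
    """Reference output of one quantized linear layer.

    xq: int8 activations [M, K]; x_scale: per-row activation scales [M]
    wb: binarized weights [N, K] (nn.Linear layout); beta: weight scale
    Returns out[m][n] = beta / x_scale[m] * sum_k xq[m][k] * wb[n][k].
    """
    out = []
    for m, row in enumerate(xq):
        acc = [sum(a * b for a, b in zip(row, wrow)) for wrow in wb]  # integer dot products
        out.append([beta * v / x_scale[m] for v in acc])             # dequantize
    return out
```

The integer accumulation is what a low-bit GEMM kernel has to compute; the two scales enter only afterwards, as one multiply per output element, so comparing `acc` against the raw kernel output isolates the kernel from the scaling.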
tanvoontao · Apr 10 '25 02:04