BitNet
BitNet copied to clipboard
CUDA kernel seems not optimized.
I'm using the kernel provided gemm_lowbit() to do inference for my model evaluation.
But it seems like the inference speed is a bit too slow.
I'm using this for my classification task.
BitNet b1 takes about 5 hours, while the baseline Transformer takes about 1 hour.
Pasted my BitLinearInference here.
def activation_norm_quant(x: Tensor):
""" RMSNorm & Per-token quantization to 8 bits. It can be implemented as a fused kernel.
Args:
x: an activation tensor with shape [n, d]
Returns:
y: a quantized activation tensor with shape [n, d]
scale: a scalar for dequantization with shape [1]
"""
x = F.rms_norm(x, normalized_shape=[x.size(-1)])
scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
y = (x * scale).round().clamp_(-128, 127)
return y, scale
# JIT-compile the custom low-bit GEMM CUDA extension at import time via
# torch.utils.cpp_extension.load. The first run pays a full nvcc compile;
# subsequent runs reuse the cached build directory.
gemm_lowbit_ext = load(
name="gemm_lowbit_ext",
sources=[
os.path.join("kernel", "gemm_lowbit_kernel.cu")  # relative to the CWD, not this file — presumably run from the repo root; verify
],
extra_cflags=["-O3", "-std=c++17"],
extra_cuda_cflags=[
"-O3",
"-std=c++17",
"--expt-relaxed-constexpr",  # allow constexpr host functions in device code
"--use_fast_math",           # faster but less precise math intrinsics
],
verbose=True,  # print the nvcc build log, useful for diagnosing compile issues
)
class BitLinearB1Inference(nn.Linear):
    """Inference-only linear layer for BitNet b1 (1-bit weights).

    Weights are expected to be quantized offline with a 1-bit scheme; this
    layer quantizes activations per token (RMSNorm + int8 absmax) and runs
    the matmul through the custom ``gemm_lowbit`` CUDA kernel, then
    dequantizes the output and applies the bias.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use1Bit: bool = True,
        norm: Literal['layer', 'rms'] = 'rms',
    ):
        super().__init__(in_features, out_features, bias)
        self.use1Bit = use1Bit
        # NOTE(review): `self.norm` is never used in forward() —
        # activation_norm_quant() applies its own RMSNorm. Kept so the
        # constructor interface and state dict stay compatible.
        self.norm = get_norm(norm, in_features)
        # Lazily-built fp16, contiguous [K, N] copy of the weight (see
        # forward); rebuilding it every call was a per-step cost.
        self._w_half = None

    def forward(self, x):
        # x: [B, S, d_model] => flatten => [M, K], where M = B*S, K = d_model
        bsz, seq_len, d_model = x.shape
        # Fused RMSNorm + per-token int8 quantization. `scale` ([B, S, 1])
        # must be kept to undo the activation quantization afterwards —
        # the original forward discarded it, leaving outputs inflated by
        # up to 127/absmax per token.
        x_quant, scale = activation_norm_quant(x)
        x_2d = x_quant.half().view(bsz * seq_len, d_model)  # [M, K]
        # The kernel wants 'b' as [K, N]; nn.Linear stores weight as [N, K].
        # Convert and transpose ONCE and cache: calling .half().t() on every
        # forward re-converts the weight and hands the kernel a
        # non-contiguous matrix, both of which are costly. Weights are
        # assumed frozen at inference time — reset `_w_half` to None if
        # they are ever updated.
        if self._w_half is None:
            self._w_half = self.weight.detach().half().t().contiguous()  # [K, N]
        # Output buffer => [M, N] = [B*S, out_features]
        out_2d = torch.empty(
            (x_2d.shape[0], self._w_half.shape[1]),  # => [M, N]
            dtype=torch.half,
            device=x.device,
        )
        # gemm_lowbit computes (M x K) @ (K x N) -> (M x N)
        gemm_lowbit_ext.gemm_lowbit(x_2d, self._w_half, out_2d, 1.0, 1.0)
        # Dequantize: y = x_norm * scale, so dividing the GEMM output by the
        # per-token scale recovers x_norm @ W.
        out_2d = out_2d / scale.view(bsz * seq_len, 1).half()
        # NOTE(review): if the offline 1-bit weight quantization carries its
        # own scale (beta), it must also be applied here — confirm against
        # the weight-quantization script.
        out_3d = out_2d.view(bsz, seq_len, -1)
        if self.bias is not None:
            # nn.Linear created a bias but the original forward never added it.
            out_3d = out_3d + self.bias.to(out_3d.dtype)
        return out_3d