Question about weight quantization methodology and memory savings

Open · nnethercott opened this issue 1 year ago · 1 comment

Thanks for your quick implementation! I was reading through bitnet/bitbnet_b158.py and just had a short question.

In your implementation of quantize_weights you follow the same procedure as the paper "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits", but the quantized weights appear to stay in float32, while the quantized activations are explicitly cast to int8. I could be missing something, but how does this save memory (beyond the 8-bit activations, as in the paper) when the quantized weights are kept as float32?
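For reference, here is the absmean scheme from the paper next to the storage arithmetic behind this question. W is an n × m weight matrix. The quoted code's sign(x) * clamp(round(|x|), max=1) is equivalent to RoundClip(x, -1, 1), and adding eps to gamma before dividing matches W / (γ + ε). The 4096 × 4096 layer size is only an illustrative assumption.

```latex
\gamma = \frac{1}{nm} \sum_{i,j} \lvert W_{ij} \rvert, \qquad
\widetilde{W} = \mathrm{RoundClip}\!\left( \frac{W}{\gamma + \epsilon}, -1, 1 \right), \qquad
\mathrm{RoundClip}(x, a, b) = \max\bigl(a, \min(b, \mathrm{round}(x))\bigr)

\text{storage per weight: } 32 \text{ bits (float32)} \;\to\; 8 \text{ bits (int8)} \;\to\; 2 \text{ bits (packed)}, \quad \text{vs. } \log_2 3 \approx 1.58 \text{ bits of information}

4096 \times 4096 = 16{,}777{,}216 \text{ weights} \;\Rightarrow\; 64\ \text{MiB (float32)},\ 16\ \text{MiB (int8)},\ 4\ \text{MiB (2-bit packed)}
```

So restricting the values to {-1, 0, 1} does not shrink anything by itself. The savings only appear once the ternary values are stored in a narrower type.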

    def quantize_weights(self, W):
        """
        Quantizes the weights using the absmean quantization function.

        Args:
            W (Tensor): The weight tensor to be quantized.

        Returns:
            Tensor: Quantized weight tensor.
        """
        gamma = torch.mean(torch.abs(W)) + self.eps
        W_scaled = W / gamma
        W_quantized = torch.sign(W_scaled) * torch.clamp(
            torch.abs(W_scaled).round(), max=1.0  # values are in {-1, 0, 1}, but the dtype is still torch.float32
        )
        return W_quantized
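Below is a minimal sketch, assuming PyTorch, of where the memory savings would have to come from. The standalone quantize_weights mirrors the quoted method without the class. pack_ternary and unpack_ternary are hypothetical helpers that are not part of this repository, and the 256 × 256 size is arbitrary. The sketch uses plain 2-bit packing (four weights per byte) for simplicity. This is not necessarily how the paper's inference kernels lay out weights.

```python
import torch


def quantize_weights(W: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Absmean ternary quantization, mirroring the method quoted above.

    The result only holds -1, 0, +1 but keeps W's dtype (float32 here).
    """
    gamma = W.abs().mean() + eps
    W_scaled = W / gamma
    return torch.sign(W_scaled) * torch.clamp(W_scaled.abs().round(), max=1.0)


def pack_ternary(W_q: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: store each ternary value in 2 bits, four per uint8 byte."""
    codes = W_q.flatten().to(torch.int64) + 1  # map {-1, 0, 1} -> {0, 1, 2}
    pad = (-codes.numel()) % 4
    if pad:
        # Pad with code 1 (the value 0) so the length is a multiple of 4.
        codes = torch.cat([codes, torch.ones(pad, dtype=torch.int64)])
    codes = codes.reshape(-1, 4)
    packed = codes[:, 0] | (codes[:, 1] << 2) | (codes[:, 2] << 4) | (codes[:, 3] << 6)
    return packed.to(torch.uint8)  # max byte value is 170, so it fits in uint8


def unpack_ternary(packed: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Hypothetical helper: inverse of pack_ternary, returns float32 -1/0/+1 values."""
    p = packed.to(torch.int64)
    # Extract the four 2-bit fields of each byte in the same order they were packed.
    codes = torch.stack([(p >> s) & 0b11 for s in (0, 2, 4, 6)], dim=1).flatten()
    return (codes[: shape.numel()] - 1).to(torch.float32).reshape(shape)


torch.manual_seed(0)
W = torch.randn(256, 256)
W_q = quantize_weights(W)
W_int8 = W_q.to(torch.int8)
packed = pack_ternary(W_q)

print(W_q.dtype, W_q.numel() * W_q.element_size())           # torch.float32 262144
print(W_int8.dtype, W_int8.numel() * W_int8.element_size())  # torch.int8 65536
print(packed.dtype, packed.numel() * packed.element_size())  # torch.uint8 16384
```

During training, the full-precision latent weights are usually kept anyway for the straight-through estimator, so packing like this would mainly pay off for inference checkpoints. The scale gamma would also need to be stored next to the packed bytes so the outputs can be rescaled.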

nnethercott, Mar 01 '24 10:03

Stale issue message

github-actions[bot], Apr 30 '24 12:04