KBLaM icon indicating copy to clipboard operation
KBLaM copied to clipboard

transformers library compatibility

Open getStRiCtd opened this issue 10 months ago • 3 comments

These classes are deprecated and slated for removal in transformers 4.46:

LlamaDynamicNTKScalingRotaryEmbedding
LlamaLinearScalingRotaryEmbedding
"`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "

However, KBLaM depends on transformers==4.48.0 and still uses these classes in KBLaM/src/kblam/models/llama3_model.py.

getStRiCtd avatar Mar 26 '25 10:03 getStRiCtd

I can reproduce the issue! Currently looking into it!

xidulu avatar Mar 26 '25 12:03 xidulu

As a workaround, you can add the following definitions to llama3_model.py, before the import of KBLaMConfig:


class LlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) table with a lazily grown cos/sin cache.

    Local stand-in for the transformers Llama rotary-embedding classes.
    ``forward`` returns the cached ``cos``/``sin`` rows for the first
    ``seq_len`` positions. The cache is rebuilt (grown) whenever a longer
    sequence than the cached one is requested.

    Args:
        dim: Rotary dimension. ``inv_freq`` is built from
            ``torch.arange(0, dim, 2)`` and the cached tables have twice as many
            columns as ``inv_freq``.
        max_position_embeddings: Number of positions cached at construction.
        base: RoPE frequency base.
        device: Device on which ``inv_freq`` and the initial cache are built.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_cached``/``sin_cached`` for positions ``0 .. seq_len - 1``.

        Raises:
            ValueError: if ``seq_len`` is None.
        """
        if seq_len is None:
            raise ValueError("seq_len must be provided and cannot be None.")

        # Deliberately no upper bound against max_position_embeddings: forward()
        # calls this precisely to grow the cache past its initial size, and the
        # subclasses' overrides also accept any length. A bound here made every
        # longer input raise instead of extending the cache.
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Non-persistent: derived from inv_freq, so kept out of the state_dict.
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, position_ids=None, seq_len: Optional[int] = None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, cast to ``x.dtype``.

        Args:
            x: Input tensor. Only its dtype, device and, when ``seq_len`` is None,
                ``x.size(2)`` are used. NOTE(review): this assumes the sequence
                axis is dim 2, e.g. (batch, heads, seq, head_dim); confirm
                against the caller.
            position_ids: Accepted for call-signature compatibility but ignored.
                NOTE(review): newer transformers attention code indexes cos/sin
                by position_ids; confirm the caller expects un-indexed tables.
            seq_len: Number of positions to return. Defaults to ``x.size(2)``.

        Returns:
            Tuple ``(cos, sin)``. Each is a ``seq_len``-row slice of the cache.
        """
        if seq_len is None:
            seq_len = x.size(2)

        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)
        return cos, sin

class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """Rotary embedding whose frequency base is enlarged for long sequences (dynamic NTK).

    When the cache is built for more than ``max_position_embeddings``
    positions, ``inv_freq`` is recomputed from an enlarged base before the
    cos/sin tables are regenerated. Credits to the Reddit users /u/bloc97 and
    /u/emozilla.

    Args:
        scaling_factor: Factor used in the base-enlargement formula.
        Other arguments are forwarded unchanged to ``LlamaRotaryEmbedding``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set before the parent constructor, because the parent builds the
        # initial cache through the overridden _set_cos_sin_cache below.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin cache for ``seq_len`` positions, rescaling ``inv_freq`` if needed."""
        self.max_seq_len_cached = seq_len

        beyond_trained_range = seq_len > self.max_position_embeddings
        if beyond_trained_range:
            # Grow the base with the sequence length. The new inv_freq replaces
            # the registered buffer and is reused by later rebuilds.
            growth = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            scaled_base = self.base * growth ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
            self.register_buffer("inv_freq", 1.0 / (scaled_base ** exponents))

        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Duplicate the angle columns so the table width is twice len(inv_freq).
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values.to(dtype), persistent=False)

class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """Rotary embedding with linearly scaled positions.

    Every position index is divided by ``scaling_factor`` before the
    cos/sin tables are computed, which compresses positions into a smaller
    angle range.

    Args:
        scaling_factor: Divisor applied to every position index.
        Other arguments are forwarded unchanged to ``LlamaRotaryEmbedding``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set before the parent constructor, because the parent builds the
        # initial cache through the overridden _set_cos_sin_cache below.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin cache for ``seq_len`` linearly scaled positions."""
        self.max_seq_len_cached = seq_len
        scaled_positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) / self.scaling_factor
        angles = torch.einsum("i,j->ij", scaled_positions, self.inv_freq)
        # Duplicate the angle columns so the table width is twice len(inv_freq).
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values.to(dtype), persistent=False)

johnyquest7 avatar Mar 29 '25 04:03 johnyquest7

Full repo at https://github.com/johnyquest7/KBLaM_mixed_precision (needs further testing).

johnyquest7 avatar Mar 29 '25 05:03 johnyquest7

Now pinned to 4.46.0 in #40. Does this solve these issues?

ti250 avatar Apr 14 '25 12:04 ti250

Yup! @ti250 , thank you

getStRiCtd avatar Apr 14 '25 13:04 getStRiCtd