KBLaM
transformers library compatibility
These classes are deprecated as of transformers==4.46:
LlamaDynamicNTKScalingRotaryEmbedding
LlamaLinearScalingRotaryEmbedding
"`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
However, KBLaM depends on transformers==4.48.0 and still uses those classes in KBLaM/src/kblam/models/llama3_model.py.
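For reference, this is roughly what the deprecation warning points at: in transformers >= 4.46 the scaling variants are folded into LlamaRotaryEmbedding and selected through the model config's rope_scaling field. The snippet below is only a sketch of that replacement API (not KBLaM code), and the exact keys/signature should be checked against the installed transformers version.

# Sketch only (assumed transformers >= 4.46 API, not taken from KBLaM):
# the separate scaling classes are replaced by LlamaRotaryEmbedding driven by config.rope_scaling.
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

config = LlamaConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
rotary_emb = LlamaRotaryEmbedding(config=config)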
I can reproduce the issue! Currently looking into it!
You can add this before importing KBLaMConfig in the llama3_model.py file:
import torch
import torch.nn as nn
from typing import Optional  # torch, nn and Optional may already be imported in llama3_model.py


class LlamaRotaryEmbedding(nn.Module):  # Or whatever base class it inherits from
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):  # Add max_position_embeddings
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings  # Add this line
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        if seq_len is None:
            raise ValueError("seq_len must be provided and cannot be None.")
        if seq_len > self.max_position_embeddings:
            raise ValueError(
                f"seq_len ({seq_len}) exceeds max_position_embeddings ({self.max_position_embeddings}). "
                "Consider increasing max_position_embeddings or using dynamic scaling."
            )
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, position_ids=None, seq_len: Optional[int] = None):
        if seq_len is None:
            seq_len = x.size(2)  # Infer seq_len from input tensor
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)
        # Debugging: print tensor shapes
        print(f"x.shape: {x.shape}, cos.shape: {cos.shape}, sin.shape: {sin.shape}")
        return cos, sin


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # Determine if we need to extend or create new position embeddings
        if seq_len > self.max_position_embeddings:
            # Dynamic scaling based on sequence length
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq)
        # Calculate the frequencies
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
Full repo at https://github.com/johnyquest7/KBLaM_mixed_precision. Needs further testing.
Now pinned to 4.46.0 in #40. Does this solve these issues?
Yup! @ti250, thank you!