ColBERT interaction correct?
Thank you very much for BGE-M3!
I am implementing something similar, and I found a line in your code that puzzles me a bit:
https://github.com/FlagOpen/FlagEmbedding/blob/2225aacb54cf9e807aa116dfffeb0cceb291b38b/FlagEmbedding/finetune/embedder/encoder_only/m3/modeling.py#L227
Might it be that the ColBERT interaction is incorrect?
The einsum includes the CLS token:
token_scores = torch.einsum('qin,pjn->qipj', q_reps, p_reps)
but the scaling mask does not:
q_mask[:, 1:].sum(-1, keepdim=True)
In the limit n → ∞ this works out correctly, but for small sequence lengths the discrepancy can become significant.
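To make the concern concrete, here is a toy sketch with made-up shapes and random vectors (not the real model): if the einsum summed over every query position including CLS while the denominator counted only the remaining tokens, scores would be scaled by q_len / (q_len - 1).

import torch

# Toy sketch of the concern (made-up shapes, random vectors, not the real model)
q_len, p_len, dim = 4, 6, 8               # q_len counts the CLS position too
q_reps = torch.randn(1, q_len, dim)
p_reps = torch.randn(1, p_len, dim)
q_mask = torch.ones(1, q_len)

token_scores = torch.einsum('qin,pjn->qipj', q_reps, p_reps)
scores = token_scores.max(-1)[0].sum(1)                    # sums over all q_len positions

full_norm = scores / q_mask.sum(-1, keepdim=True)          # divide by q_len
trunc_norm = scores / q_mask[:, 1:].sum(-1, keepdim=True)  # divide by q_len - 1
print(trunc_norm / full_norm)                              # q_len / (q_len - 1) = 1.33 here

For q_len = 4 that is already a 33% inflation; only for long queries does the factor approach 1.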
What do you think?
In my research I avoid letting the CLS token's embeddings interact; do you see better results with that?
Good question, @approximated-intelligence. Can you share your new code?
@Sandy4321 My mistake, I had another look at the code: the embeddings and the mask are truncated in different functions, so the normalization works out (see the short shape check after the two snippets below):
def compute_colbert_score(self, q_reps, p_reps, q_mask: torch.Tensor = None):
    """Compute the colbert score.
    Args:
        q_reps (torch.Tensor): Query representations.
        p_reps (torch.Tensor): Passage representations.
        q_mask (torch.Tensor, optional): Query attention mask. Defaults to None.
    Returns:
        torch.Tensor: The computed colbert scores, adjusted by temperature.
    """
    token_scores = torch.einsum('qin,pjn->qipj', q_reps, p_reps)
    scores, _ = token_scores.max(-1)
    scores = scores.sum(1) / q_mask[:, 1:].sum(-1, keepdim=True)
    scores = scores / self.temperature
    return scores
and:
def _colbert_embedding(self, last_hidden_state, mask):
    """Get the colbert vectors.
    Args:
        last_hidden_state (torch.Tensor): The model output's last hidden state.
        mask (torch.Tensor): Attention mask used to zero out padding tokens.
    Returns:
        torch.Tensor: The colbert vectors.
    """
    colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
    colbert_vecs = colbert_vecs * mask[:, 1:][:, :, None].float()
    return colbert_vecs
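Here is the short shape check I did to convince myself (a minimal sketch with made-up sizes and random tensors): _colbert_embedding already drops position 0, so the einsum sums over L-1 query vectors, and q_mask[:, 1:] counts exactly those L-1 positions, so the CLS vector never enters the interaction.

import torch

# Minimal shape check with made-up sizes (B=2 queries, L=5 tokens, D=3 dims)
B, L, D = 2, 5, 3
last_hidden_state = torch.randn(B, L, D)
attention_mask = torch.ones(B, L)

# What _colbert_embedding produces: position 0 (CLS) is already dropped
q_reps = last_hidden_state[:, 1:]        # [B, L-1, D]
q_mask = attention_mask                  # the full mask is passed separately

# Inside compute_colbert_score the sum over i runs over L-1 positions ...
p_reps = torch.randn(B, L - 1, D)
token_scores = torch.einsum('qin,pjn->qipj', q_reps, p_reps)  # [B, L-1, B, L-1]
scores = token_scores.max(-1)[0].sum(1)                        # sums L-1 terms

# ... and the denominator counts the same L-1 positions, so they match
denom = q_mask[:, 1:].sum(-1, keepdim=True)                    # = L-1 per query
print(token_scores.shape[1], int(denom[0].item()))             # 4 4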
Now of course I wonder what happens to the [SEP] token at the end of the input.
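From the snippet above, only index 0 is dropped, so the trailing </s> (XLM-R's SEP/EOS, token id 2) still contributes a vector to the MaxSim. If one wanted to exclude it as well, a hypothetical variant could mask it via the input ids, something like this sketch (not what the repo does; the function name and arguments are mine):

import torch

def colbert_embedding_no_eos(colbert_linear, last_hidden_state, mask, input_ids, eos_token_id=2):
    """Hypothetical variant of _colbert_embedding: drop CLS (position 0) and
    additionally zero the trailing </s> vector. Not the repo's behavior."""
    colbert_vecs = colbert_linear(last_hidden_state[:, 1:])
    # keep = non-padding and non-EOS positions (XLM-R's </s> has id 2 by default)
    keep = mask[:, 1:].float() * (~input_ids[:, 1:].eq(eos_token_id)).float()
    # the score normalization would then divide by keep.sum(-1)
    # instead of q_mask[:, 1:].sum(-1)
    return colbert_vecs * keep[:, :, None]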
As you asked for some code: I wrote a very stripped-down loading and inference script to understand how BGE-M3 works:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import XLMRobertaModel, XLMRobertaConfig, AutoTokenizer
from typing import Optional, Dict, Any

class XLMRobertaM3Embedder(XLMRobertaModel):
    """
    XLMRoberta model with sparse embedding head for BGE-M3 style sparse retrieval.
    Extends XLMRobertaModel to add sparse and ColBERT linear head functionality.
    """
    def __init__(self, config: XLMRobertaConfig, sparse_dim: int = 1):
        super().__init__(config)
        # Add sparse linear head
        self.sparse_linear = nn.Linear(config.hidden_size, sparse_dim)
        # Optional: Add ColBERT head
        self.colbert_linear = nn.Linear(config.hidden_size, config.hidden_size)
        # Initialize new layers
        self._init_sparse_weights()

    def _init_sparse_weights(self):
        """Initialize sparse and colbert linear layers"""
        nn.init.xavier_uniform_(self.sparse_linear.weight)
        nn.init.constant_(self.sparse_linear.bias, 0)
        nn.init.xavier_uniform_(self.colbert_linear.weight)
        nn.init.constant_(self.colbert_linear.bias, 0)

    def dense_embedding(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor,
                        pooling: str = "cls") -> torch.Tensor:
        """Extract dense embeddings using specified pooling"""
        if pooling == "cls":
            return hidden_states[:, 0]  # CLS token
        elif pooling == "mean":
            # Mean pooling with attention mask
            masked_hidden = hidden_states * attention_mask.unsqueeze(-1).float()
            summed = masked_hidden.sum(dim=1)
            lengths = attention_mask.sum(dim=1, keepdim=True).float()
            return summed / lengths
        else:
            raise ValueError(f"Unsupported pooling: {pooling}")

    def sparse_embedding(self, hidden_states: torch.Tensor, input_ids: torch.Tensor,
                         attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Generate sparse embeddings in vocabulary space.
        Maps token-level weights to vocabulary positions.
        """
        # Apply sparse linear + ReLU
        token_weights = F.relu(self.sparse_linear(hidden_states))  # [B, L, 1]
        token_weights = token_weights.squeeze(-1)  # [B, L]
        # Mask padding tokens
        token_weights = token_weights * attention_mask.float()
        # Map to vocabulary space using scatter_reduce
        batch_size = input_ids.size(0)
        vocab_size = self.config.vocab_size
        sparse_embedding = torch.zeros(
            batch_size, vocab_size,
            dtype=token_weights.dtype,
            device=token_weights.device
        )
        # Use scatter_reduce with amax to handle repeated tokens
        sparse_embedding = sparse_embedding.scatter_reduce(
            dim=-1,
            index=input_ids,
            src=token_weights,
            reduce="amax"
        )
        # Zero out special tokens
        special_tokens = [
            self.config.bos_token_id,
            self.config.eos_token_id,
            self.config.pad_token_id,
            # self.config.unk_token_id  # not available in the XLMRobertaConfig
        ]
        for token_id in special_tokens:
            if token_id is not None:
                sparse_embedding[:, token_id] = 0.0
        return sparse_embedding

    def colbert_embedding(self, hidden_states: torch.Tensor,
                          attention_mask: torch.Tensor) -> torch.Tensor:
        """Generate ColBERT-style multi-vector embeddings (exclude CLS)"""
        # Skip CLS token, apply linear transformation
        colbert_vecs = self.colbert_linear(hidden_states[:, 1:])  # [B, L-1, D]
        # Apply attention mask (excluding CLS position)
        mask = attention_mask[:, 1:].unsqueeze(-1).float()  # [B, L-1, 1]
        colbert_vecs = colbert_vecs * mask
        # Normalize
        colbert_vecs = F.normalize(colbert_vecs, dim=-1)
        return colbert_vecs

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                return_dense: bool = True, return_sparse: bool = False,
                return_colbert: bool = False, pooling: str = "cls",
                **kwargs) -> Dict[str, torch.Tensor]:
        """
        Forward pass with multi-granularity embedding extraction
        """
        # Get base model outputs
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        hidden_states = outputs.last_hidden_state  # [B, L, H]
        results = {}
        if return_dense:
            dense_emb = self.dense_embedding(hidden_states, attention_mask, pooling)
            results['dense_embeddings'] = F.normalize(dense_emb, dim=-1)
        if return_sparse:
            sparse_emb = self.sparse_embedding(hidden_states, input_ids, attention_mask)
            results['sparse_embeddings'] = sparse_emb
        if return_colbert:
            colbert_emb = self.colbert_embedding(hidden_states, attention_mask)
            results['colbert_embeddings'] = colbert_emb
        return results

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, **kwargs):
        """
        Load from BGE-M3 checkpoint and extract sparse/colbert heads
        """
        # Load the base XLMRoberta backbone weights
        model = super().from_pretrained(model_name_or_path, **kwargs)
        # Try to load sparse and colbert heads if they exist
        try:
            import os
            sparse_path = os.path.join(model_name_or_path, 'sparse_linear.pt')
            colbert_path = os.path.join(model_name_or_path, 'colbert_linear.pt')
            if os.path.exists(sparse_path):
                sparse_state = torch.load(sparse_path, map_location='cpu')
                model.sparse_linear.load_state_dict(sparse_state)
            if os.path.exists(colbert_path):
                colbert_state = torch.load(colbert_path, map_location='cpu')
                model.colbert_linear.load_state_dict(colbert_state)
        except Exception as e:
            print(f"Warning: Could not load sparse/colbert heads: {e}")
            print("Using randomly initialized heads")
        return model

# Usage example
def main():
    model = XLMRobertaM3Embedder.from_pretrained("models/bge-m3")
    tokenizer = AutoTokenizer.from_pretrained("models/bge-m3")
    # Example usage
    texts = ["Hello world", "This is a test"]
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    print(inputs)
    # Get all embedding types
    with torch.no_grad():
        results = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            return_dense=True,
            return_sparse=True,
            return_colbert=True
        )
    print(f"Dense embeddings shape: {results['dense_embeddings'].shape}")
    print(f"Sparse embeddings shape: {results['sparse_embeddings'].shape}")
    print(f"ColBERT embeddings shape: {results['colbert_embeddings'].shape}")
    # Access the raw XLMRoberta encoder directly - NO WRAPPER LAYERS!
    encoder_layers = model.encoder.layer
    print(f"Number of encoder layers: {len(encoder_layers)}")
    # Direct access to any component
    attention_layer = model.encoder.layer[0].attention
    sparse_head = model.sparse_linear
    colbert_head = model.colbert_linear
    print(f"Sparse head: {sparse_head}")

if __name__ == "__main__":
    main()
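And to close the loop on the ColBERT question: this is how I would score queries against passages with the colbert_embeddings the script returns. It mirrors compute_colbert_score from above minus the temperature; the function and variable names are mine, and the shapes are assumed from the listing.

import torch

def colbert_maxsim(q_vecs, p_vecs, q_mask):
    """MaxSim score between query and passage multi-vectors (sketch).

    q_vecs: [Bq, Lq-1, D]  results['colbert_embeddings'] of the queries
    p_vecs: [Bp, Lp-1, D]  results['colbert_embeddings'] of the passages
    q_mask: [Bq, Lq]       original attention mask of the queries
    """
    # Pairwise token similarities: [Bq, Lq-1, Bp, Lp-1]
    token_scores = torch.einsum('qin,pjn->qipj', q_vecs, p_vecs)
    # Best-matching passage token per query token, summed over query tokens
    scores = token_scores.max(-1)[0].sum(1)          # [Bq, Bp]
    # Same normalization as compute_colbert_score (temperature left out)
    return scores / q_mask[:, 1:].sum(-1, keepdim=True)

# e.g., scoring the two example texts against each other:
# print(colbert_maxsim(results['colbert_embeddings'],
#                      results['colbert_embeddings'],
#                      inputs['attention_mask']))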