FlagEmbedding icon indicating copy to clipboard operation
FlagEmbedding copied to clipboard

Is the ColBERT interaction correct?

Open approximated-intelligence opened this issue 5 months ago • 4 comments

Thank you very much for BGE-M3!

I am implementing something similar, and I found a line in your code that puzzles me a bit:

https://github.com/FlagOpen/FlagEmbedding/blob/2225aacb54cf9e807aa116dfffeb0cceb291b38b/FlagEmbedding/finetune/embedder/encoder_only/m3/modeling.py#L227

Might it be that the ColBERT interaction is incorrect?

The einsum includes the CLS token:

token_scores = torch.einsum('qin,pjn->qipj', q_reps, p_reps)

The scaling mask does not:

q_mask[:, 1:].sum(-1, keepdim=True)

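Because the sum would contain one more term (the CLS token's) than the token count it is divided by, this only works out in the limit $n \to \infty$; for short sequences the discrepancy can become significant.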

What do you think?

In my research I avoid letting the CLS token's embedding interact; do you see better results with it included?

Good question. @approximated-intelligence, can you share your new code?

Sandy4321 avatar Aug 31 '25 17:08 Sandy4321

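@Sandy4321 My mistake: I had another look at the code, and the embeddings and the mask are truncated in different functions, so it works out. `_colbert_embedding` already drops the CLS position from the vectors, so the einsum never sees it, and `compute_colbert_score` divides by `q_mask[:, 1:]` to match: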

    def compute_colbert_score(self, q_reps, p_reps, q_mask: torch.Tensor=None):
        """Compute the colbert (late-interaction / MaxSim) score.

        For every (query, passage) pair, each query token is matched to its
        most similar passage token. Those maxima are summed over the query
        tokens, divided by the number of non-CLS query tokens, and then
        divided by the temperature.

        Args:
            q_reps (torch.Tensor): Query representations, rank 3 as consumed
                by the einsum below: (query, query token, dim).
            p_reps (torch.Tensor): Passage representations, rank 3:
                (passage, passage token, dim); last dim must match q_reps.
            q_mask (torch.Tensor): Query attention mask that still includes
                the CLS position; only ``q_mask[:, 1:]`` is used, to count the
                query tokens the score is averaged over. Although it defaults
                to ``None``, it is indexed unconditionally, so callers must
                pass it.

        Returns:
            torch.Tensor: The computed colbert scores, one per
            (query, passage) pair, adjusted by temperature.
        """
        # Similarity of every query token i with every passage token j -> [q, i, p, j].
        token_scores = torch.einsum('qin,pjn->qipj', q_reps, p_reps)
        # MaxSim: best-matching passage token for each query token -> [q, i, p].
        scores, _ = token_scores.max(-1)
        # Sum over query tokens -> [q, p], divided by a [q, 1] token count.
        # The mask drops position 0 because q_reps is assumed to already exclude
        # CLS, as _colbert_embedding does. Padded query vectors are assumed to be
        # zeroed so they add 0 to the sum (TODO confirm for other callers).
        scores = scores.sum(1) / q_mask[:, 1:].sum(-1, keepdim=True)
        scores = scores / self.temperature
        return scores

and:

    def _colbert_embedding(self, last_hidden_state, mask):
        """Get the colbert vectors.

        Position 0 (CLS) is dropped, so the result has one token fewer than
        the input. Every later position, including the final end-of-sequence
        token, is kept. Padded positions are zeroed by multiplying with the mask.

        Args:
            last_hidden_state (torch.Tensor): The model output's last hidden state.
            mask (torch.Tensor): Attention mask aligned with
                ``last_hidden_state`` (CLS position included); used to zero
                out padding tokens.

        Returns:
            torch.Tensor: The colbert vectors, without the CLS position.
        """
        # Skip position 0 (CLS) before projecting.
        colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
        # Drop the same position from the mask so it lines up, then broadcast
        # it over the last (hidden) dimension.
        colbert_vecs = colbert_vecs * mask[:, 1:][:, :, None].float()
        return colbert_vecs

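Now, of course, I wonder what happens to the [SEP] token at the end of the input. Only the first position is sliced off, and the attention mask is 1 at the end token, so it is kept and takes part in the interaction.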

Since you asked for code: I wrote some very stripped-down loading and inference code to understand how BGE-M3 works:

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import XLMRobertaModel, XLMRobertaConfig, AutoTokenizer
from typing import Optional, Dict, Any


class XLMRobertaM3Embedder(XLMRobertaModel):
    """
    XLMRoberta encoder with BGE-M3 style dense, sparse and ColBERT outputs.

    Extends ``XLMRobertaModel`` (the bare encoder, no LM head) with two extra
    linear heads: ``sparse_linear`` for per-token lexical weights and
    ``colbert_linear`` for per-token multi-vector embeddings.
    """

    def __init__(self, config: XLMRobertaConfig, sparse_dim: int = 1):
        """Build the encoder and attach the sparse and ColBERT heads.

        Args:
            config: Encoder configuration; ``hidden_size`` sizes both heads.
            sparse_dim: Output width of ``sparse_linear``. ``sparse_embedding``
                squeezes the last dimension, so it only works with 1.
        """
        super().__init__(config)

        # Per-token lexical weight head (one scalar per token).
        self.sparse_linear = nn.Linear(config.hidden_size, sparse_dim)

        # ColBERT head: projects every token vector, keeping hidden_size dims.
        self.colbert_linear = nn.Linear(config.hidden_size, config.hidden_size)

        # The heads are not part of the base encoder weights; initialise them
        # explicitly until from_pretrained loads the checkpoint heads.
        self._init_sparse_weights()

    def _init_sparse_weights(self):
        """Xavier-initialise the sparse and ColBERT head weights and zero their biases."""
        for head in (self.sparse_linear, self.colbert_linear):
            nn.init.xavier_uniform_(head.weight)
            nn.init.constant_(head.bias, 0)

    def dense_embedding(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor,
                        pooling: str = "cls") -> torch.Tensor:
        """Pool token states into one vector per sequence.

        Args:
            hidden_states: Encoder output, ``[B, L, H]``.
            attention_mask: ``[B, L]``, non-zero for real tokens, 0 for padding.
            pooling: ``"cls"`` takes position 0; ``"mean"`` averages the
                unmasked positions.

        Returns:
            ``[B, H]`` pooled embeddings (normalisation happens in ``forward``).

        Raises:
            ValueError: If ``pooling`` is not supported.
        """
        if pooling == "cls":
            return hidden_states[:, 0]  # CLS token
        if pooling == "mean":
            mask = attention_mask.unsqueeze(-1).float()  # [B, L, 1]
            summed = (hidden_states * mask).sum(dim=1)   # [B, H]
            # Clamp so an all-padding row yields zeros instead of NaN.
            lengths = mask.sum(dim=1).clamp(min=1.0)     # [B, 1]
            return summed / lengths
        raise ValueError(f"Unsupported pooling: {pooling}")

    def sparse_embedding(self, hidden_states: torch.Tensor, input_ids: torch.Tensor,
                         attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Generate sparse embeddings in vocabulary space.

        Each token gets a non-negative weight (ReLU of ``sparse_linear``), and
        padding tokens get 0. The weights are scattered to their token id,
        keeping the maximum when an id occurs more than once. Columns of
        special tokens are then zeroed.

        Returns:
            ``[B, vocab_size]`` tensor of lexical weights.
        """
        token_weights = F.relu(self.sparse_linear(hidden_states)).squeeze(-1)  # [B, L]
        token_weights = token_weights * attention_mask.float()  # mask padding

        sparse_embedding = torch.zeros(
            input_ids.size(0), self.config.vocab_size,
            dtype=token_weights.dtype,
            device=token_weights.device,
        )
        # amax handles repeated token ids; a zero base is safe because the
        # weights are >= 0 after the ReLU.
        sparse_embedding = sparse_embedding.scatter_reduce(
            dim=-1,
            index=input_ids,
            src=token_weights,
            reduce="amax",
        )

        # XLMRobertaConfig may not define unk_token_id, hence the getattr.
        special_ids = [
            token_id
            for token_id in (
                self.config.bos_token_id,
                self.config.eos_token_id,
                self.config.pad_token_id,
                getattr(self.config, "unk_token_id", None),
            )
            if token_id is not None
        ]
        if special_ids:
            index = torch.tensor(special_ids, dtype=torch.long, device=sparse_embedding.device)
            # Out-of-place on purpose: scatter_reduce(amax) saves its output for
            # the backward pass, so an in-place write here makes training fail
            # with "modified by an inplace operation".
            sparse_embedding = sparse_embedding.index_fill(-1, index, 0.0)

        return sparse_embedding

    def colbert_embedding(self, hidden_states: torch.Tensor,
                          attention_mask: torch.Tensor) -> torch.Tensor:
        """Generate ColBERT-style multi-vector embeddings, excluding the CLS position.

        Returns:
            ``[B, L-1, H]`` L2-normalised token vectors; padded positions are
            zero vectors.
        """
        # Skip CLS token, apply linear transformation.
        colbert_vecs = self.colbert_linear(hidden_states[:, 1:])  # [B, L-1, D]

        # Apply attention mask (excluding CLS position).
        mask = attention_mask[:, 1:].unsqueeze(-1).float()  # [B, L-1, 1]
        colbert_vecs = colbert_vecs * mask

        # F.normalize leaves the zeroed padding vectors at zero (eps guard).
        return F.normalize(colbert_vecs, dim=-1)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                return_dense: bool = True, return_sparse: bool = False,
                return_colbert: bool = False, pooling: str = "cls",
                **kwargs) -> Dict[str, torch.Tensor]:
        """
        Run the encoder and extract the requested embedding types.

        Args:
            input_ids: ``[B, L]`` token ids.
            attention_mask: ``[B, L]`` mask; defaults to all ones (no padding).
            return_dense / return_sparse / return_colbert: Which outputs to compute.
            pooling: Dense pooling mode, see ``dense_embedding``.
            **kwargs: Forwarded to ``XLMRobertaModel.forward``.

        Returns:
            Dict with any of ``'dense_embeddings'`` (L2-normalised),
            ``'sparse_embeddings'`` and ``'colbert_embeddings'``.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        # Index 0 is last_hidden_state for both ModelOutput and the plain
        # tuple returned when return_dict=False is passed through kwargs.
        hidden_states = outputs[0]  # [B, L, H]

        results = {}
        if return_dense:
            dense_emb = self.dense_embedding(hidden_states, attention_mask, pooling)
            results['dense_embeddings'] = F.normalize(dense_emb, dim=-1)
        if return_sparse:
            results['sparse_embeddings'] = self.sparse_embedding(
                hidden_states, input_ids, attention_mask
            )
        if return_colbert:
            results['colbert_embeddings'] = self.colbert_embedding(hidden_states, attention_mask)
        return results

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, **kwargs):
        """
        Load the encoder from a BGE-M3 checkpoint, then load the sparse and ColBERT heads.

        The heads are read from ``sparse_linear.pt`` / ``colbert_linear.pt``
        inside ``model_name_or_path``, so it must be a local directory. A head
        whose file is missing or fails to load keeps its random
        initialisation, and a warning is emitted for it.
        """
        import os
        import warnings

        model = super().from_pretrained(model_name_or_path, **kwargs)

        heads = (
            ('sparse_linear.pt', model.sparse_linear),
            ('colbert_linear.pt', model.colbert_linear),
        )
        for filename, head in heads:
            path = os.path.join(model_name_or_path, filename)
            if not os.path.exists(path):
                warnings.warn(
                    f"{filename} not found in {model_name_or_path!r}; "
                    "using randomly initialized head",
                    stacklevel=2,
                )
                continue
            try:
                # NOTE(security): torch.load unpickles the file; only load trusted
                # checkpoints (consider weights_only=True on torch >= 1.13).
                state = torch.load(path, map_location='cpu')
                head.load_state_dict(state)
            except Exception as e:  # best-effort: keep the encoder usable
                warnings.warn(
                    f"Could not load {filename}: {e}; using randomly initialized head",
                    stacklevel=2,
                )

        return model


# Usage example
def main(model_path: str = "models/bge-m3"):
    """Load a local BGE-M3 checkpoint and print the shapes of all three embedding types.

    Args:
        model_path: Directory holding the checkpoint and tokenizer files and,
            optionally, ``sparse_linear.pt`` / ``colbert_linear.pt``.
    """
    model = XLMRobertaM3Embedder.from_pretrained(model_path)
    model.eval()  # explicit: no dropout during inference
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    texts = ["Hello world", "This is a test"]
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    print(inputs)

    # Get all embedding types.
    with torch.no_grad():
        results = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            return_dense=True,
            return_sparse=True,
            return_colbert=True,
        )

    print(f"Dense embeddings shape: {results['dense_embeddings'].shape}")
    print(f"Sparse embeddings shape: {results['sparse_embeddings'].shape}")
    print(f"ColBERT embeddings shape: {results['colbert_embeddings'].shape}")

    # The model subclasses XLMRobertaModel directly, so encoder submodules and
    # the extra heads are plain attributes.
    print(f"Number of encoder layers: {len(model.encoder.layer)}")
    print(f"Sparse head: {model.sparse_linear}")
    print(f"ColBERT head: {model.colbert_linear}")


if __name__ == "__main__":
    main()