
Gemini embedder models

Open motapinto opened this issue 10 months ago • 4 comments

Is your feature request related to a problem? Please describe. We are integrating the Gemini embedder into our chatbot, given its high ranking on the MTEB leaderboard.

Describe the solution you'd like I would like a built-in embedder in the Haystack SDK to support it.

motapinto avatar Apr 02 '25 14:04 motapinto

The most logical place to add this would be the google_ai integration in core integrations, since it uses the same SDK. E.g., this announcement shows some example code for calling it:

from google import genai

client = genai.Client(api_key="GEMINI_API_KEY")

result = client.models.embed_content(
        model="gemini-embedding-exp-03-07",
        contents="How does alphafold work?",
)

print(result.embeddings)

sjrl avatar Apr 04 '25 07:04 sjrl

I implemented a custom component to do this for my work, here is the snippet if that helps:

import logging
from typing import List, Dict, Any, Optional, Union
from datetime import datetime

from haystack import component, Document
from haystack.utils.auth import Secret
from google import genai
from google.genai import types

from app.constants import DEFAULT_EMBEDDING_MODEL, GEMINI_API_KEY

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


@component
class GeminiEmbedder:
    """
    Embeds texts (single strings or Documents) using Google Gemini models.

    Automatically handles differentiation between query-like strings
    and document-like inputs based on the type passed to the `run` method.

    Usage example:
    ```python
    from haystack import Document
    from haystack.utils.auth import Secret

    embedder = GeminiEmbedder(api_key="your-api-key", model="text-embedding-004")

    # Embed a single query
    text = "What is the capital of France?"
    result_query = embedder.run(text)
    print(f"Query Embedding Snippet: {result_query['embedding'][:5]}...")
    # result_query has the shape {'embedding': [...], 'documents': [], 'meta': {...}}

    # Embed documents
    docs = [
        Document(content="Paris is the capital of France."),
        Document(content="Berlin is the capital of Germany.")
    ]
    result_docs = embedder.run(docs)
    print(f"Doc 0 Embedding Snippet: {result_docs['documents'][0].embedding[:5]}...")
    # result_docs has the shape {'embedding': [], 'documents': [Document(...), Document(...)], 'meta': {...}}
    ```
    """

    def __init__(self,
                 api_key: Secret = GEMINI_API_KEY,
                 model: str = DEFAULT_EMBEDDING_MODEL,
                 batch_size: int = 100):
        """
        Initializes the GeminiEmbedder component.

        :param model: Gemini embedding model name.
        :param api_key: Your Google API Key.
        :param batch_size: Number of texts to embed in a single API call when processing Documents (max 100 recommended).
        """
        self.api_key = api_key.resolve_value()
        self.model = model
        self.batch_size = batch_size

        # Log warning if batch size exceeds recommendation
        if self.batch_size > 100:
            logger.warning(f"Gemini API recommends a max batch size of 100. Provided: {self.batch_size}.")

        self.client = genai.Client(api_key=self.api_key)

    def _embed_batch(self, batch_texts: List[str], embed_config: types.EmbedContentConfig) -> List[Optional[List[float]]]:
        """Internal method to embed a batch of texts using the configured model and task type."""
        embeddings_result: List[Optional[List[float]]] = [None] * len(batch_texts)
        if not batch_texts:
            return embeddings_result

        try:
            # Make the API call using the globally configured client
            response = self.client.models.embed_content(
                model=self.model,
                contents=batch_texts,
                config=embed_config
            )
            return [embedding.values for embedding in response.embeddings]

        except Exception as e:
            logger.error(f"Error embedding batch with Gemini (Task: {embed_config}): {e}", exc_info=True)
        return embeddings_result

    @component.output_types(embedding=List[float], documents=List[Document], meta=Dict[str, Any])
    def run(self, data: Union[str, List[Document]]):
        """
        Embeds a single query string or a list of documents.

        :param data: A string (for query embedding) or a List[Document] (for document embedding).
        :return: A dictionary containing:
                 - `embedding`: List[float] if input was str, else empty list.
                 - `documents`: List[Document] with updated embeddings if input was List[Document], else empty list.
                 - `meta`: A dictionary with the model name and the timestamp of the call.
        """
        output_meta = {"model": self.model,
                       "date": datetime.now().isoformat()}

        # === Handle single string input (interpreted as query) ===
        if isinstance(data, str):
            if not data: # Handle empty string case
                logger.warning("Received empty string for embedding.")
                return {"embedding": [], "documents": [], "meta": output_meta}

            embedding_result = self._embed_batch(
                batch_texts=[data],
                embed_config=types.EmbedContentConfig(task_type="RETRIEVAL_QUERY")
            )[0] # Get the first (and only) embedding

            return {"embedding": embedding_result or [], "documents": [], "meta": output_meta}

        # === Handle list of Document input ===
        elif isinstance(data, list):
            # Check if it's a list of Documents
            if not data or not all(isinstance(doc, Document) for doc in data):
                raise TypeError(
                    "Input list must contain Haystack Document objects. Received non-Document elements."
                )

            # Keep one text per document (falling back to an empty string when content is
            # missing) so the embeddings stay aligned with `data` when assigned back below.
            texts_to_embed = [
                (doc.meta.get("analysis_embedding_summary", "") + "\nContent: \n" + (doc.content or "")).strip()
                for doc in data
            ]
            all_embeddings: List[Optional[List[float]]] = []

            # Process in batches
            for i in range(0, len(texts_to_embed), self.batch_size):
                batch_texts = texts_to_embed[i : i + self.batch_size]
                batch_embeddings = self._embed_batch(
                    batch_texts=batch_texts,
                    embed_config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
                )
                all_embeddings.extend(batch_embeddings)

            # Assign embeddings back to documents, handling potential mismatches
            if len(all_embeddings) != len(data):
                logger.error(
                    f"Number of embeddings received ({len(all_embeddings)}) "
                    f"does not match the number of documents ({len(data)}). "
                    f"Setting all document embeddings to None for this run."
                )
                for doc in data:
                    doc.embedding = None
            else:
                for doc, emb in zip(data, all_embeddings):
                    doc.embedding = emb  # The list, or None if embedding failed for that item

            return {"embedding": [], "documents": data, "meta": output_meta}

        else:
            raise TypeError(f"GeminiEmbedder.run expects a str or a List[Document], got {type(data).__name__}.")

It is quite barebones, but if you'd like I can create a PR with a proper implementation over the weekend.
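For readers skimming the snippet, the batching loop inside `run` can be sketched in isolation as a plain helper (this is just an illustrative standalone sketch, not part of the component above):

```python
from typing import List


def chunk(texts: List[str], batch_size: int = 100) -> List[List[str]]:
    """Split texts into consecutive batches of at most batch_size items,
    mirroring the `range(0, len(texts), batch_size)` loop in the component."""
    return [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]


batches = chunk([f"text {i}" for i in range(250)], batch_size=100)
print([len(b) for b in batches])  # → [100, 100, 50]
```

Concatenating the per-batch results in order is what keeps the final embedding list aligned with the input documents.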

AboveTheHeavens avatar Apr 04 '25 09:04 AboveTheHeavens

If you could do that @AboveTheHeavens that would be greatly appreciated!

This looks pretty far along; I'd only recommend checking out some of our other embedders to see how they are implemented, to make sure we follow the same pattern. So things like:

  • Create separate GeminiTextEmbedder and GeminiDocumentEmbedder components
  • Add to_dict and from_dict methods to properly handle the serialization and deserialization of the api_key
  • Expose other parameters at init time that users might find useful. E.g. Our OpenAIDocumentEmbedder exposes these additional params.
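The second bullet's intent is that the API key should never be serialized as a raw value, only as a reference that is re-resolved on load. Here is a minimal, self-contained sketch of that pattern; the `EnvSecret` and `SketchEmbedder` names are hypothetical stand-ins for Haystack's `Secret` utilities and the eventual embedder component, not the real API:

```python
import os
from typing import Dict, Any


class EnvSecret:
    """Hypothetical stand-in for a Haystack-style env-var-backed secret."""

    def __init__(self, env_var: str):
        self.env_var = env_var

    def resolve_value(self) -> str:
        # The actual key is read from the environment only when needed.
        return os.environ[self.env_var]

    def to_dict(self) -> Dict[str, Any]:
        # Serialize the *reference* (the env var name), never the key itself.
        return {"type": "env_var", "env_var": self.env_var}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EnvSecret":
        return cls(data["env_var"])


class SketchEmbedder:
    """Hypothetical embedder showing the to_dict/from_dict round trip."""

    def __init__(self, api_key: EnvSecret, model: str = "text-embedding-004"):
        self.api_key = api_key
        self.model = model

    def to_dict(self) -> Dict[str, Any]:
        return {"init_parameters": {"api_key": self.api_key.to_dict(), "model": self.model}}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SketchEmbedder":
        params = data["init_parameters"]
        return cls(api_key=EnvSecret.from_dict(params["api_key"]), model=params["model"])
```

With this shape, a serialized pipeline contains only `{"type": "env_var", "env_var": "..."}` for the key, and `from_dict` reconstructs a secret that resolves at run time.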

sjrl avatar Apr 04 '25 11:04 sjrl

Sounds good, I'll raise a PR following those implementation patterns.

AboveTheHeavens avatar Apr 04 '25 16:04 AboveTheHeavens