
feat: `MultiModalRetriever`

Open · ZanSara opened this pull request 3 years ago • 6 comments

Related Issue(s):

  • Closes #2865
  • Closes #2857
  • Related to #2418

Proposed changes:

  • Create a multimodal Retriever by generalizing the concepts introduced by TableTextRetriever
  • Introduce a stack of new supporting classes for this Retriever, such as MultiModalEmbedder
  • Note that this Retriever will NOT be tested for working in pipelines, but only in isolation. It will also, most likely, stay undocumented. See #2418 for the rationale.

Additional context:

  • As mentioned in the original issue, an attempt to generalize TableTextRetriever quickly proved too complex for the scope of this PR.
  • Rather than modifying an existing Retriever and risking breaking working code, I opted to clone the class and its stack of supporting classes, and then make the changes needed to support N models rather than just 3 (see the sketch after this list).
  • A later goal is to perform table retrieval with MultiModalRetriever and to use its stack to retire TriAdaptiveModel, BiAdaptiveModel and maybe AdaptiveModel itself, along with their respective helpers (custom prediction heads, custom processors, etc.).
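
As an illustration of the "N models" idea: the core of such a Retriever is a dispatch from Document.content_type to the embedding model registered for that type. A minimal, hypothetical sketch (the PR's real classes are considerably more involved):

from collections import defaultdict
from typing import Dict, List

from haystack import Document

def group_by_content_type(documents: List[Document]) -> Dict[str, List[Document]]:
    # Each group can then be embedded by the model registered for its content
    # type, e.g. {"text": <text model>, "table": <table model>, "image": <image model>}
    groups: Dict[str, List[Document]] = defaultdict(list)
    for doc in documents:
        groups[doc.content_type].append(doc)
    return dict(groups)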

Additional changes:

  • I soon realized that image support requires generalizing the concept of a tokenizer. So I renamed haystack/modeling/models/tokenization.py -> haystack/modeling/models/feature_extraction.py, created a class called FeatureExtractor, and used it as a uniform interface over AutoTokenizer and AutoFeatureExtractor (a simplified sketch follows).
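
For illustration, a much-simplified, hypothetical sketch of what such a uniform interface can look like (the actual FeatureExtractor in feature_extraction.py is more thorough):

from transformers import AutoFeatureExtractor, AutoTokenizer

class FeatureExtractor:
    def __init__(self, pretrained_model_name_or_path: str):
        try:
            # Text models ship a tokenizer...
            self.extractor = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        except (OSError, ValueError):
            # ...while vision models ship a feature extractor instead.
            self.extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)

    def __call__(self, *args, **kwargs):
        # Same call interface regardless of the underlying backend
        return self.extractor(*args, **kwargs)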

Pre-flight checklist

  • [X] I have read the contributors guidelines
  • [ ] If this is a code change, I added tests or updated existing ones
  • [ ] If this is a code change, I updated the docstrings

ZanSara · Jul 27 '22 09:07


At this point MultiModalRetriever is confirmed to be able to do regular text-to-text retrieval.

Example comparison between MultiModalRetriever and EmbeddingRetriever on the same HF model:

import logging

from haystack.nodes.retriever import EmbeddingRetriever
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore

from haystack.nodes.retriever.multimodal import MultiModalRetriever

logging.basicConfig(format="%(levelname)s - %(name)s -  %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(level=logging.INFO)

docs = [
    Document(content="My name is Christelle and I live in Paris"),
    Document(content="My name is Camila and I don't live in Rome, but in Madrid"),
    Document(content="My name is Matteo and I live in Rome, not in Madrid"),
    Document(content="My name is Yoshiko and I live in Tokyo"),
    Document(content="My name is Fatima and I live in Morocco"),
    Document(content="My name is Lin and I live in Shanghai, but I lived in Rome before"),
    Document(content="My name is John and I live in Sidney and I like Rome a lot"),
    Document(content="My name is Tanay and I live in Delhi"),
    Document(content="My name is Boris and I live in Omsk"),
    Document(content="My name is Maria and I live in Maputo"),
]

docstore_mm = InMemoryDocumentStore()
docstore_mm.write_documents(docs)

docstore_emb = InMemoryDocumentStore()
docstore_emb.write_documents(docs)

retriever_mm = MultiModalRetriever(
    document_store=docstore_mm,
    query_embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
    query_type="text",
    passage_embedding_models={"text": "sentence-transformers/multi-qa-mpnet-base-dot-v1"},
)
retriever_emb = EmbeddingRetriever(
    document_store=docstore_emb,
    embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
)
docstore_mm.update_embeddings(retriever=retriever_mm)
docstore_emb.update_embeddings(retriever=retriever_emb)


results_mm = retriever_mm.retrieve("Who lives in Rome?", top_k=10)
results_mm = sorted(results_mm, key=lambda d: d.score, reverse=True)

results_emb = retriever_emb.retrieve("Who lives in Rome?", top_k=10)
results_emb = sorted(results_emb, key=lambda d: d.score, reverse=True)

print("\nMultiModalRetriever:")
for doc in results_mm:
    print(doc.score, doc.content)
print("\nEmbeddingRetriever:")
for doc in results_emb:
    print(doc.score, doc.content)

Output:

MultiModalRetriever:
0.5208672765153813 My name is John and I live in Sidney and I like Rome a lot
0.5191388809402967 My name is Lin and I live in Shanghai, but I lived in Rome before
0.5186486197556036 My name is Camila and I don't live in Rome, but in Madrid
0.5179478461149706 My name is Matteo and I live in Rome, not in Madrid
0.5142559657927063 My name is Boris and I live in Omsk
0.5139828892846635 My name is Maria and I live in Maputo
0.5128719158976963 My name is Christelle and I live in Paris
0.5123557966096112 My name is Fatima and I live in Morocco
0.5119127839282492 My name is Yoshiko and I live in Tokyo
0.5104531420951921 My name is Tanay and I live in Delhi

EmbeddingRetriever:
0.5565343848340962 My name is John and I live in Sidney and I like Rome a lot
0.5516977455360549 My name is Lin and I live in Shanghai, but I lived in Rome before
0.5477355154764227 My name is Camila and I don't live in Rome, but in Madrid
0.5471289177353567 My name is Matteo and I live in Rome, not in Madrid
0.5391855575892676 My name is Christelle and I live in Paris
0.5387820795590201 My name is Maria and I live in Maputo
0.5374046173004233 My name is Boris and I live in Omsk
0.5352793033810584 My name is Yoshiko and I live in Tokyo
0.5349538540865688 My name is Fatima and I live in Morocco
0.5307010966429587 My name is Tanay and I live in Delhi

The scores of the two Retrievers differ slightly, and so does the ranking of the documents that are not relevant to the query; I don't know yet what causes this. The relevant documents are ranked in the same order, so for now I consider this Retriever to work fine on this task.
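
For reference, a quick way to see exactly where the two rankings diverge, reusing the results_mm and results_emb lists from the script above:

# Compare the two orderings position by position
order_mm = [doc.content for doc in results_mm]
order_emb = [doc.content for doc in results_emb]

for rank, (mm_doc, emb_doc) in enumerate(zip(order_mm, order_emb)):
    marker = "" if mm_doc == emb_doc else "  <-- differs"
    print(f"{rank}: {mm_doc} | {emb_doc}{marker}")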

ZanSara · Aug 24 '22 09:08

Currently MultiModalRetriever has no batching of any sort. Benchmarking its runtime against EmbeddingRetriever shows that it takes about twice as long on average.

Script:
from datetime import datetime
import logging

from haystack.nodes.retriever import EmbeddingRetriever
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore

from haystack.nodes.retriever.multimodal import MultiModalRetriever


logging.basicConfig(format="%(levelname)s - %(name)s -  %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(level=logging.INFO)


docs = [
    Document(content="My name is Christelle and I live in Paris"),
    Document(content="My name is Camila and I don't live in Rome, but in Madrid"),
    Document(content="My name is Matteo and I live in Rome, not in Madrid"),
    Document(content="My name is Yoshiko and I live in Tokyo"),
    Document(content="My name is Fatima and I live in Morocco"),
    Document(content="My name is Lin and I live in Shanghai, but I lived in Rome before"),
    Document(content="My name is John and I live in Sidney and I like Rome a lot"),
    Document(content="My name is Tanay and I live in Delhi"),
    Document(content="My name is Boris and I live in Omsk"),
    Document(content="My name is Maria and I live in Maputo"),
]

retrievers = {
    "MultiModalRetriever": lambda docstore: MultiModalRetriever(
        document_store=docstore,
        query_embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        query_type="text",
        passage_embedding_models={"text": "sentence-transformers/multi-qa-mpnet-base-dot-v1"},
    ),
    "EmbeddingRetriever": lambda docstore: EmbeddingRetriever(
        document_store=docstore,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
    ),
}

runtimes = {}
iterations = 100
for name, get_retriever in retrievers.items():

    docstore = InMemoryDocumentStore()
    docstore.write_documents(docs)

    retriever = get_retriever(docstore)
    docstore.update_embeddings(retriever=retriever)

    start = datetime.now()
    for _ in range(iterations):
        results = retriever.retrieve("Who lives in Rome?", top_k=10)
    stop = datetime.now()

    results = sorted(results, key=lambda d: d.score, reverse=True)

    # total_seconds() keeps sub-second precision (.seconds truncates to whole seconds)
    runtimes[name] = f"Runtime: {(stop - start).total_seconds() / iterations}s"

print(runtimes)

Results:

{'MultiModalRetriever': 'Runtime: 0.12s', 'EmbeddingRetriever': 'Runtime: 0.05s'}

ZanSara · Aug 24 '22 09:08

Seems like MultiModalRetriever can be used for table retrieval as well.

Script comparing with EmbeddingRetriever:
from datetime import datetime
import logging
import json

import pandas as pd

from haystack.nodes.retriever import EmbeddingRetriever
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.utils import fetch_archive_from_http

from haystack.nodes.retriever.multimodal import MultiModalRetriever

logging.basicConfig(format="%(levelname)s - %(name)s -  %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(level=logging.INFO)

doc_dir = "data/tutorial15"
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/table_text_dataset.zip"
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)

def read_tables(filename):
    processed_tables = []
    with open(filename) as tables:
        tables = json.load(tables)
        for key, table in tables.items():
            current_columns = table["header"]
            current_rows = table["data"]
            current_df = pd.DataFrame(columns=current_columns, data=current_rows)
            document = Document(content=current_df, content_type="table", id=key)
            processed_tables.append(document)

    return processed_tables

docs = read_tables(f"{doc_dir}/tables.json")

retrievers = {
    "MultiModalRetriever": lambda docstore: MultiModalRetriever(
        document_store=docstore,
        query_embedding_model="deepset/all-mpnet-base-v2-table",
        passage_embedding_models={"table": "deepset/all-mpnet-base-v2-table"},
        batch_size=50,
    ),
    "EmbeddingRetriever": lambda docstore: EmbeddingRetriever(
        document_store=docstore,
        embedding_model="deepset/all-mpnet-base-v2-table",  # "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    ),
}

runtimes = {}
iterations = 1
for name, get_retriever in retrievers.items():

    docstore = InMemoryDocumentStore()
    docstore.write_documents(docs)

    retriever = get_retriever(docstore)
    docstore.update_embeddings(retriever=retriever)

    start = datetime.now()
    for _ in range(iterations):
        results = retriever.retrieve("Who won the Super Bowl?", top_k=3)
    stop = datetime.now()

    results = sorted(results, key=lambda d: d.score, reverse=True)

    # total_seconds() keeps sub-second precision (.seconds truncates to whole seconds)
    runtimes[name] = f"Runtime: {(stop - start).total_seconds() / iterations}s"

    print("\nRESULTS:")
    for doc in results:
        print(" -> ", doc.score, "\n", doc.content)

print(runtimes)

ZanSara · Aug 24 '22 11:08

Image Retrieval is now functional thanks to sentence-transformers.

import os

from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever.multimodal import MultiModalRetriever

# My images folder contains various pictures of animals from Wikipedia.
docs = [
    Document(content=f"./examples/images/{filename}", content_type="image")
    for filename in os.listdir("./examples/images")
]
# Mixed retrieval not working yet - postponed to another PR
# docs += [
#     Document(content="A zebra is a horse-like animal that lives in Africa"),
#     Document(content="A lion is an african big cat that feeds on zebras"),
#     Document(content="Tuna is a large ocean fish.")
# ]

docstore = InMemoryDocumentStore(embedding_dim=512)
docstore.write_documents(docs)

retriever = MultiModalRetriever(
    document_store=docstore,
    query_embedding_model="sentence-transformers/clip-ViT-B-32",
    query_type="text",
    passage_embedding_models={
        # "text": "sentence-transformers/clip-ViT-B-32",  # Mixed retrieval not working yet - postponed to another PR
        "image": "sentence-transformers/clip-ViT-B-32"
    },
    batch_size=100,
)
docstore.update_embeddings(retriever=retriever)

query = "An animal that lives in the mountains"  # example query
results = retriever.retrieve(query, top_k=10)

results = sorted(results, key=lambda d: d.score, reverse=True)

print("\nRESULTS:")
for doc in results:
    print(doc.score, doc.content.replace("./examples/images/", ""))

No modifications to the document stores or to the primitives were necessary.

There is a good chance that this Retriever can work as-is in a regular Haystack pipeline (although that hasn't been tested yet).

Comparison with EmbeddingRetriever

Replacing the EmbeddingRetriever with a MultiModalRetriever in Tutorial 15 works smoothly. The same goes for regular text-to-text retrieval (see the examples in the comments above). In addition, leveraging sentence-transformers made the original speed difference disappear: the two Retrievers are now equally performant on both table and text retrieval.
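
For context, sentence-transformers batches the encoding internally, which is most likely what closed the gap. A minimal illustration, reusing the model name from the snippets above:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/multi-qa-mpnet-base-dot-v1")
texts = ["My name is Matteo and I live in Rome"] * 100

# encode() processes the inputs in chunks of batch_size instead of one by one
embeddings = model.encode(texts, batch_size=32)
print(embeddings.shape)  # (100, 768)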

TODO:

  • Make sure the architecture is consistent
  • Tests for text-to-text, text-to-table and text-to-image (a possible shape is sketched below)
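
Purely as an illustration, a hypothetical text-to-text test could look like the following (the tests actually added to this PR may be structured differently):

import pytest

from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever.multimodal import MultiModalRetriever

@pytest.mark.integration
def test_multimodal_text_to_text_retrieval():
    docstore = InMemoryDocumentStore()
    docstore.write_documents([
        Document(content="My name is Matteo and I live in Rome"),
        Document(content="My name is Tanay and I live in Delhi"),
    ])
    retriever = MultiModalRetriever(
        document_store=docstore,
        query_embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        document_embedding_models={"text": "sentence-transformers/multi-qa-mpnet-base-dot-v1"},
    )
    docstore.update_embeddings(retriever=retriever)
    results = retriever.retrieve("Who lives in Rome?", top_k=1)
    assert "Rome" in results[0].content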

ZanSara · Sep 06 '22 09:09

Hey @vblagoje, unless something else comes up with the tests, this PR should be stable and ready to review. Tests are being added but we can start to iterate on the architecture :blush:

ZanSara · Sep 14 '22 16:09

@agnieszka-m Link to the only review item left open (now hidden in a hell of collapsibles): https://github.com/deepset-ai/haystack/pull/2891#discussion_r993153149

ZanSara · Oct 12 '22 16:10

Seems like MultiModalRetriever works in Pipelines as well without issues. Example snippet for text-to-image retrieval:

import os

from haystack import Document, Pipeline
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import MultiModalRetriever

docs = [
    Document(content=f"./examples/images/{filename}", content_type="image")
    for filename in os.listdir("./examples/images")
]

docstore_mm = InMemoryDocumentStore(embedding_dim=512)
docstore_mm.write_documents(docs)

retriever_mm = MultiModalRetriever(
    document_store=docstore_mm,
    query_embedding_model="sentence-transformers/clip-ViT-B-32",
    query_type="text",
    document_embedding_models={"image": "sentence-transformers/clip-ViT-B-32"},
)

docstore_mm.update_embeddings(retriever=retriever_mm)

pipeline = Pipeline()
pipeline.add_node(component=retriever_mm, name="retriever", inputs=["Query"])

results_mm = pipeline.run(query="An animal that lives in the mountains")

results_mm = sorted(results_mm["documents"], key=lambda d: d.score, reverse=True)
for doc in results_mm:
    print(doc.score, doc.content)

ZanSara · Oct 13 '22 13:10