feat: `MultiModalRetriever`
Related Issue(s):
- Closes #2865
- Closes #2857
- Related to #2418
Proposed changes:
- Create a multimodal retriever by generalizing the concepts introduced by `TableTextRetriever`.
- It introduces a stack of new subclasses to support such a retriever, such as `MultiModalEmbedder`.
- Note that this Retriever will NOT be tested for working in pipelines, but only to work in isolation. It will also, most likely, stay undocumented. See #2418 for the rationale.
Additional context:
- As mentioned in the original issue, an attempt to generalize `TableTextRetriever` quickly proved too complex for the scope of this PR.
- Rather than modifying an existing Retriever with the risk of breaking working code, I opted for cloning the class and its stack of supporting classes and performing the changes needed to support N models rather than just 3.
- A later goal is to be able to perform table retrieval with `MultiModalRetriever` and use its stack to dispose of `TriAdaptiveModel`, `BiAdaptiveModel`, and maybe `AdaptiveModel` itself, along with their respective helpers (custom prediction heads, custom processors, etc.).
Additional changes:
- Soon I realized that with image support we need to generalize the concept of a tokenizer. So I renamed `haystack/modeling/models/tokenization.py` -> `haystack/modeling/models/feature_extraction.py`, created a class called `FeatureExtractor`, and used it as a uniform interface over `AutoTokenizer` and `AutoFeatureExtractor` (see the sketch below).
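For reference, the rough idea behind `FeatureExtractor` is sketched below. This is not the exact code from this PR; the dispatch on `config.model_type` is a simplified, hypothetical example of how a single interface can hide the tokenizer vs. feature-extractor distinction from callers.
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer

class FeatureExtractor:
    """Uniform interface over AutoTokenizer (text) and AutoFeatureExtractor (images)."""

    def __init__(self, pretrained_model_name_or_path: str, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
        # Hypothetical dispatch: image model types get a feature extractor,
        # everything else falls back to a tokenizer.
        if config.model_type in ("vit", "clip", "detr"):
            self.extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
        else:
            self.extractor = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)

    def __call__(self, *args, **kwargs):
        # Both tokenizers and feature extractors are callable and return model-ready features
        return self.extractor(*args, **kwargs)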
Pre-flight checklist
- [X] I have read the contributors guidelines
- [ ] If this is a code change, I added tests or updated existing ones
- [ ] If this is a code change, I updated the docstrings
At this point MultiModalRetriever is confirmed to be able to do regular text2text retrieval.
Example comparison between MultiModalRetriever and EmbeddingRetriever on the same HF model
import logging
from haystack.nodes.retriever import EmbeddingRetriever
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever.multimodal import MultiModalRetriever
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(level=logging.INFO)
docs = [
Document(content="My name is Christelle and I live in Paris"),
Document(content="My name is Camila and I don't live in Rome, but in Madrid"),
Document(content="My name is Matteo and I live in Rome, not in Madrid"),
Document(content="My name is Yoshiko and I live in Tokyo"),
Document(content="My name is Fatima and I live in Morocco"),
Document(content="My name is Lin and I live in Shanghai, but I lived in Rome before"),
Document(content="My name is John and I live in Sidney and I like Rome a lot"),
Document(content="My name is Tanay and I live in Delhi"),
Document(content="My name is Boris and I live in Omsk"),
Document(content="My name is Maria and I live in Maputo"),
]
docstore_mm = InMemoryDocumentStore()
docstore_mm.write_documents(docs)
docstore_emb = InMemoryDocumentStore()
docstore_emb.write_documents(docs)
retriever_mm = MultiModalRetriever(
document_store=docstore_mm,
query_embedding_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1",
query_type="text",
passage_embedding_models = {"text": "sentence-transformers/multi-qa-mpnet-base-dot-v1"},
)
retriever_emb = EmbeddingRetriever(
document_store=docstore_emb,
embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
)
docstore_mm.update_embeddings(retriever=retriever_mm)
docstore_emb.update_embeddings(retriever=retriever_emb)
results_mm = retriever_mm.retrieve("Who lives in Rome?", top_k=10)
results_mm = sorted(results_mm, key=lambda d: d.score, reverse=True)
results_emb = retriever_emb.retrieve("Who lives in Rome?", top_k=10)
results_emb = sorted(results_emb, key=lambda d: d.score, reverse=True)
print("\nMultiModalRetriever:")
for doc in results_mm:
print(doc.score, doc.content)
print("\nEmbeddingRetriever:")
for doc in results_emb:
print(doc.score, doc.content)
Output:
MultiModalRetriever:
0.5208672765153813 My name is John and I live in Sidney and I like Rome a lot
0.5191388809402967 My name is Lin and I live in Shanghai, but I lived in Rome before
0.5186486197556036 My name is Camila and I don't live in Rome, but in Madrid
0.5179478461149706 My name is Matteo and I live in Rome, not in Madrid
0.5142559657927063 My name is Boris and I live in Omsk
0.5139828892846635 My name is Maria and I live in Maputo
0.5128719158976963 My name is Christelle and I live in Paris
0.5123557966096112 My name is Fatima and I live in Morocco
0.5119127839282492 My name is Yoshiko and I live in Tokyo
0.5104531420951921 My name is Tanay and I live in Delhi
EmbeddingRetriever:
0.5565343848340962 My name is John and I live in Sidney and I like Rome a lot
0.5516977455360549 My name is Lin and I live in Shanghai, but I lived in Rome before
0.5477355154764227 My name is Camila and I don't live in Rome, but in Madrid
0.5471289177353567 My name is Matteo and I live in Rome, not in Madrid
0.5391855575892676 My name is Christelle and I live in Paris
0.5387820795590201 My name is Maria and I live in Maputo
0.5374046173004233 My name is Boris and I live in Omsk
0.5352793033810584 My name is Yoshiko and I live in Tokyo
0.5349538540865688 My name is Fatima and I live in Morocco
0.5307010966429587 My name is Tanay and I live in Delhi
The scores of the two Retrievers differ slightly, and so does the ranking of the documents that are not relevant to the query; I'm not sure yet what causes this. The relevant documents are ranked the same way, so for now I consider this retriever to work fine on this task.
Currently MultiModalRetriever has no batching of any sort. Benchmarking its runtime against EmbeddingRetriever shows it takes about twice as long on average.
Script
from datetime import datetime
import logging
from haystack.nodes.retriever import EmbeddingRetriever
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever.multimodal import MultiModalRetriever
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(level=logging.INFO)
docs = [
Document(content="My name is Christelle and I live in Paris"),
Document(content="My name is Camila and I don't live in Rome, but in Madrid"),
Document(content="My name is Matteo and I live in Rome, not in Madrid"),
Document(content="My name is Yoshiko and I live in Tokyo"),
Document(content="My name is Fatima and I live in Morocco"),
Document(content="My name is Lin and I live in Shanghai, but I lived in Rome before"),
Document(content="My name is John and I live in Sidney and I like Rome a lot"),
Document(content="My name is Tanay and I live in Delhi"),
Document(content="My name is Boris and I live in Omsk"),
Document(content="My name is Maria and I live in Maputo"),
]
retrievers = {
"MultiModalRetriever": lambda docstore: MultiModalRetriever(
document_store=docstore,
query_embedding_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1",
query_type="text",
passage_embedding_models = {"text": "sentence-transformers/multi-qa-mpnet-base-dot-v1"},
),
"EmbeddingRetriever": lambda docstore: EmbeddingRetriever(
document_store=docstore,
embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
)
}
runtimes = {}
iterations = 100
for name, get_retriever in retrievers.items():
docstore = InMemoryDocumentStore()
docstore.write_documents(docs)
retriever = get_retriever(docstore)
docstore.update_embeddings(retriever=retriever)
start = datetime.now()
for _ in range(iterations):
results = retriever.retrieve("Who lives in Rome?", top_k=10)
stop = datetime.now()
results = sorted(results, key=lambda d: d.score, reverse=True)
runtimes[name] = f"Runtime: {(stop-start).seconds/iterations}s"
print(runtimes)
Results:
{'MultiModalRetriever': 'Runtime: 0.12s', 'EmbeddingRetriever': 'Runtime: 0.05s'}
Seems like MultiModalRetriever can be used for table retrieval as well.
Script comparing with EmbeddingRetriever
from datetime import datetime
import logging
import json
import pandas as pd
from haystack.nodes.retriever import EmbeddingRetriever
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.utils import fetch_archive_from_http
from haystack.nodes.retriever.multimodal import MultiModalRetriever
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(level=logging.INFO)
doc_dir = "data/tutorial15"
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/table_text_dataset.zip"
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
def read_tables(filename):
processed_tables = []
with open(filename) as tables:
tables = json.load(tables)
for key, table in tables.items():
current_columns = table["header"]
current_rows = table["data"]
current_df = pd.DataFrame(columns=current_columns, data=current_rows)
document = Document(content=current_df, content_type="table", id=key)
processed_tables.append(document)
return processed_tables
docs = read_tables(f"{doc_dir}/tables.json")
retrievers = {
"MultiModalRetriever": lambda docstore: MultiModalRetriever(
document_store=docstore,
query_embedding_model = "deepset/all-mpnet-base-v2-table",
passage_embedding_models = {
"table": "deepset/all-mpnet-base-v2-table"
},
batch_size=50
),
"EmbeddingRetriever": lambda docstore: EmbeddingRetriever(
document_store=docstore,
embedding_model="deepset/all-mpnet-base-v2-table", #"sentence-transformers/multi-qa-mpnet-base-dot-v1",
)
}
runtimes = {}
iterations = 1
for name, get_retriever in retrievers.items():
docstore = InMemoryDocumentStore()
docstore.write_documents(docs)
retriever = get_retriever(docstore)
docstore.update_embeddings(retriever=retriever)
start = datetime.now()
for _ in range(iterations):
results = retriever.retrieve("Who won the Super Bowl?", top_k=3)
stop = datetime.now()
results = sorted(results, key=lambda d: d.score, reverse=True)
runtimes[name] = f"Runtime: {(stop-start).seconds/iterations}s"
print("\nRESULTS:")
for doc in results:
print(" -> ", doc.score, "\n", doc.content)
print(runtimes)
Image Retrieval is now functional thanks to sentence-transformers.
import os

from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever.multimodal import MultiModalRetriever

# My images folder contains various pictures of animals from Wikipedia.
docs = [
Document(content=f"./examples/images/{filename}", content_type="image")
for filename in os.listdir("./examples/images")
]
# Mixed retrieval not working yet - postponed to another PR
# docs += [
# Document(content="A zebra is a horse-like animal that lives in Africa"),
# Document(content="A lion is an african big cat that feeds on zebras"),
# Document(content="Tuna is a large ocean fish.")
# ]
docstore = InMemoryDocumentStore(embedding_dim=512)
docstore.write_documents(docs)
retriever = MultiModalRetriever(
document_store=docstore,
query_embedding_model = "sentence-transformers/clip-ViT-B-32",
query_type="text",
passage_embedding_models = {
# "text": "sentence-transformers/clip-ViT-B-32", # Mixed retrieval not working yet - postponed to another PR
"image": "sentence-transformers/clip-ViT-B-32"
},
batch_size=100
)
docstore.update_embeddings(retriever=retriever)
query = "An animal that lives in the mountains"  # example text query against the image index
results = retriever.retrieve(query, top_k=10)
results = sorted(results, key=lambda d: d.score, reverse=True)
print("\nRESULTS:")
for doc in results:
print(doc.score, doc.content.replace("./examples/images/", ""))
No modifications to the document stores, or to the primitives, were necessary.
There is a good chance that this Retriever might be able to work as it is in a regular Haystack pipeline (although that hasn't been tested yet).
Comparison with EmbeddingRetriever
Replacing the EmbeddingRetriever with a MultiModalRetriever in Tutorial 15 works smoothly. The same goes for regular text-to-text retrieval (see the examples in the comments), and leveraging sentence-transformers made the original speed difference disappear: the two retrievers are now equally performant on both table and text retrieval.
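For reference, the swapped Tutorial 15 pipeline boils down to something like the sketch below. This is a minimal sketch rather than the tutorial verbatim: it assumes the table Documents built by read_tables() in the script above, and the TableReader model name is an assumption (swap in whichever TAPAS checkpoint you use).
from haystack import Pipeline
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import TableReader
from haystack.nodes.retriever.multimodal import MultiModalRetriever

# "docs" is the list of table Documents produced by read_tables() above
docstore = InMemoryDocumentStore()
docstore.write_documents(docs)

retriever = MultiModalRetriever(
    document_store=docstore,
    query_embedding_model="deepset/all-mpnet-base-v2-table",
    passage_embedding_models={"table": "deepset/all-mpnet-base-v2-table"},
)
docstore.update_embeddings(retriever=retriever)

# Assumed reader: a TAPAS-based TableReader, as in the table QA tutorial
reader = TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")

pipeline = Pipeline()
pipeline.add_node(component=retriever, name="retriever", inputs=["Query"])
pipeline.add_node(component=reader, name="reader", inputs=["retriever"])

prediction = pipeline.run(query="Who won the Super Bowl?", params={"retriever": {"top_k": 3}})
for answer in prediction["answers"]:
    print(answer.score, answer.answer)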
TODO:
- Make sure the architecture is consistent
- Tests for text-to-text, text-to-table, and text-to-image retrieval (a rough sketch of the text-to-text case is below)
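For the text-to-text case, a minimal test could look roughly like this. It simply reuses the model and parameters from the text-to-text example above; fixtures, markers, and the real assertions are still to be defined.
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever.multimodal import MultiModalRetriever

def test_multimodal_text_to_text_retrieval():
    docstore = InMemoryDocumentStore()
    docstore.write_documents([
        Document(content="My name is Matteo and I live in Rome, not in Madrid"),
        Document(content="My name is Christelle and I live in Paris"),
        Document(content="My name is Tanay and I live in Delhi"),
    ])
    retriever = MultiModalRetriever(
        document_store=docstore,
        query_embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        query_type="text",
        passage_embedding_models={"text": "sentence-transformers/multi-qa-mpnet-base-dot-v1"},
    )
    docstore.update_embeddings(retriever=retriever)
    results = retriever.retrieve("Who lives in Rome?", top_k=1)
    # The only document mentioning Rome should be ranked first
    assert "Rome" in results[0].content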
Hey @vblagoje, unless something else comes up with the tests, this PR should be stable and ready to review. Tests are being added but we can start to iterate on the architecture :blush:
@agnieszka-m Link to the only one review item left open (now hidden in a collapsibles' hell): https://github.com/deepset-ai/haystack/pull/2891#discussion_r993153149
Seems like MMRetriever works in Pipelines as well without issues. Example snippet for text-to-image retrieval:
import os
import logging
from haystack import Document, Pipeline
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import MultiModalRetriever
docs = [
Document(content=f"./examples/images/{filename}", content_type="image")
for filename in os.listdir("./examples/images")
]
docstore_mm = InMemoryDocumentStore(embedding_dim=512)
docstore_mm.write_documents(docs)
retriever_mm = MultiModalRetriever(
document_store=docstore_mm,
query_embedding_model = "sentence-transformers/clip-ViT-B-32",
query_type="text",
document_embedding_models = {"image": "sentence-transformers/clip-ViT-B-32"}
)
docstore_mm.update_embeddings(retriever=retriever_mm)
pipeline = Pipeline()
pipeline.add_node(component=retriever_mm, name="retriever", inputs=["Query"])
results_mm = pipeline.run(query="An animal that lives in the mountains")
results_mm = sorted(results_mm["documents"], key=lambda d: d.score, reverse=True)
for doc in results_mm:
print(doc.score, doc.content)