Avoid missing packages and attn_mask dtype error
I installed the repo for Visualized-BGE on CPU, following the instructions in FlagEmbedding/visual/README.md, and downloaded the weights from HF. When running the example code from the README:
```python
####### Use Visualized BGE doing multi-modal knowledge retrieval
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE

model = Visualized_BGE(model_name_bge = "BAAI/bge-base-en-v1.5", model_weight="path: Visualized_base_en_v1.5.pth")
model.eval()
with torch.no_grad():
    query_emb = model.encode(text="Are there sidewalks on both sides of the Mid-Hudson Bridge?")
    candi_emb_1 = model.encode(text="The Mid-Hudson Bridge, spanning the Hudson River between Poughkeepsie and Highland.", image="./imgs/wiki_candi_1.jpg")
    candi_emb_2 = model.encode(text="Golden_Gate_Bridge", image="./imgs/wiki_candi_2.jpg")
    candi_emb_3 = model.encode(text="The Mid-Hudson Bridge was designated as a New York State Historic Civil Engineering Landmark by the American Society of Civil Engineers in 1983. The bridge was renamed the \"Franklin Delano Roosevelt Mid-Hudson Bridge\" in 1994.")

sim_1 = query_emb @ candi_emb_1.T
sim_2 = query_emb @ candi_emb_2.T
sim_3 = query_emb @ candi_emb_3.T
print(sim_1, sim_2, sim_3) # tensor([[0.6932]]) tensor([[0.4441]]) tensor([[0.6415]])
```
I got two errors:
1. Missing packages: peft and sentencepiece, the former needed for BAAI/bge-base-en-v1.5 and the latter for BAAI/bge-m3. I added both to setup.py, and after pip-installing them everything worked. Note: I have no experience with setup.py-based installations, so please check that this is correct.
2. Dtype mismatch: this happens when encoding text only, without an image.
RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype, but got attn_mask.dtype: c10::Half and query.dtype: float instead.
This is solved by casting the mask to the embedding dtype with extended_attention_mask = extended_attention_mask.to(embedding_output.dtype), which I added at modeling.py:205. With that change everything works, and the three similarity values at the end of the snippet above are reproduced.
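Regarding item 1, one quick way to confirm that both packages are importable before running the example is a small check like the one below. It is a hypothetical helper, not part of FlagEmbedding, and it assumes the import names match the pip package names (`peft` and `sentencepiece`). It only uses `importlib.util.find_spec` from the standard library.

```python
import importlib.util

# Import names of the two packages reported as missing (assumed to equal
# their pip distribution names).
REQUIRED_MODULES = ["peft", "sentencepiece"]


def missing_modules(names):
    """Hypothetical helper: return the names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
        print("Install with: pip install " + " ".join(missing))
    else:
        print("All required packages are installed.")
```

This does not replace declaring the dependencies in setup.py; it just makes the failure obvious before the model is loaded.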
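Regarding item 2, the error can be reproduced outside FlagEmbedding with a few tensors. This is a minimal sketch, assuming the mask ends up in `torch.nn.functional.scaled_dot_product_attention` (the function that enforces this dtype check). `additive_mask` is a hypothetical stand-in for the BERT-style extended mask, and `q` stands in for `embedding_output`; neither is taken from modeling.py.

```python
import torch
import torch.nn.functional as F


def additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical BERT-style extended mask: 0 where attended, a large negative value where padded."""
    ext = attention_mask[:, None, None, :].to(dtype)  # shape (batch, 1, 1, seq)
    return (1.0 - ext) * torch.finfo(dtype).min


torch.manual_seed(0)
batch, heads, seq, dim = 1, 2, 4, 8
# float32 query/key/value, as in a CPU run (q plays the role of embedding_output)
q, k, v = (torch.randn(batch, heads, seq, dim) for _ in range(3))
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last token is padding

# Mask built in half precision, mimicking the c10::Half mask from the error
half_mask = additive_mask(attention_mask, torch.float16)
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=half_mask)
    raised = False
except RuntimeError as err:
    raised = True
    print("Reproduced:", err)

# The fix from the report: cast the mask to the dtype of the activations
fixed_mask = half_mask.to(q.dtype)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=fixed_mask)
print(out.dtype, tuple(out.shape))
```

Casting the mask, rather than the activations, keeps the text encoder running in its original precision and only touches a small tensor.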
It would be nice if you could merge this so I don't have to rely on my own fork for further work. Cheers!