[Feature] Utilize InternVL 2.0 for image-text retrieval
Motivation
Thanks for your excellent work.
I found that InternVL-G includes the retrieval code, which can also be found in clip_benchmark.
I wonder how we can use the InternVL 2.0 (or InternVL-Chat) codebase for image-text retrieval tasks, which would also benefit further training.
Related resources
No response
Additional context
No response
The main development difficulty is that I can't find the code that projects the embeddings into the CLIP embedding space, like:
self.text_projection = nn.Parameter(torch.empty(transformer_width, clip_embed_dim))
https://github.com/OpenGVLab/InternVL/blob/main/clip_benchmark/clip_benchmark/models/internvl_c_pytorch/internvl_c.py#L348
so the text embeddings cannot be projected into the CLIP embedding space.
Can you help me solve this problem?
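In InternVL-C, `text_projection` is just a learned linear map into the shared embedding space, followed by L2 normalization before the contrastive loss. A minimal sketch of retrofitting such a head onto pooled backbone features; the names `hidden_dim` and `clip_embed_dim`, and the last-token pooling, are assumptions here, not the actual InternVL-Chat API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CLIPProjectionHead(nn.Module):
    """Learned linear projection into a shared CLIP-style embedding space,
    playing the same role as self.text_projection in internvl_c.py."""

    def __init__(self, hidden_dim: int, clip_embed_dim: int):
        super().__init__()
        self.proj = nn.Parameter(torch.empty(hidden_dim, clip_embed_dim))
        nn.init.normal_(self.proj, std=hidden_dim ** -0.5)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_dim) from the backbone.
        # Pool (here: last token, an assumption), project, L2-normalize.
        pooled = hidden_states[:, -1, :]
        embeds = pooled @ self.proj
        return F.normalize(embeds, dim=-1)

# Hypothetical usage with dummy backbone outputs:
head = CLIPProjectionHead(hidden_dim=4096, clip_embed_dim=768)
text_itc = head(torch.randn(2, 16, 4096))  # -> (2, 768), unit-norm rows
```

Only this head would need training if the backbone stays frozen; the pooling strategy (last token vs. mean pooling) is a design choice worth ablating.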
I have decided to retrain a CLIP projection layer in InternVL-Chat, like clip_projector or text_projection from InternVL-G. However, when I perform contrastive learning on the image embeddings and gather the tensors from all GPUs using GatherLayer, training freezes:
image_itc_all = GatherLayer.apply(image_itc).flatten(0, 1)
backbone_embeds_all = GatherLayer.apply(backbone_embeds).flatten(0, 1)
text_itc_all = GatherLayer.apply(text_itc).flatten(0, 1)
image_ids_all = GatherLayer.apply(image_ids).flatten(0, 1)
Sincerely asking for your help!
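For reference, the gradient-preserving gather used in BLIP/LAVIS-style ITC losses looks like the sketch below; if a version like this also hangs, the usual cause is that not every rank reaches the collective (e.g. a rank-dependent branch, or uneven last batches). This is a sketch assuming `torch.distributed` is already initialized, not the actual InternVL training code:

```python
import torch
import torch.distributed as dist

class GatherLayer(torch.autograd.Function):
    """all_gather with a backward pass: gradients are all-reduced and
    each rank keeps the slice corresponding to its own input."""

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)  # plain all_gather has no autograd
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)  # sum gradients from every rank
        return all_gradients[dist.get_rank()]

def gather_cat(x: torch.Tensor) -> torch.Tensor:
    # Equivalent to GatherLayer.apply(x) followed by flatten(0, 1) in the
    # snippet above: a (world_size * batch, ...) tensor.
    return torch.cat(GatherLayer.apply(x), dim=0)

# Hypothetical usage inside the ITC loss (requires an initialized group):
# image_itc_all = gather_cat(image_itc)
# text_itc_all  = gather_cat(text_itc)
```

Every rank must call the gather on tensors of identical shape, in the same order; any rank that skips one call deadlocks the whole group.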
I have also tried other approaches, such as the following, but it still does not work:
image_itc_all = torch.distributed.nn.all_gather(image_itc)
backbone_embeds_all = torch.distributed.nn.all_gather(backbone_embeds)
text_itc_all = torch.distributed.nn.all_gather(text_itc)
image_ids_all = torch.distributed.nn.all_gather(image_ids)
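One thing to note about this variant: `torch.distributed.nn.all_gather` does support autograd, but it returns a Python list of per-rank tensors, so the result has to be concatenated before it can be used in place of a single tensor. A sketch of the intended usage, again assuming the process group is initialized:

```python
import torch
import torch.distributed as dist
import torch.distributed.nn  # exposes the autograd-aware collectives

def gather_with_grad(x: torch.Tensor) -> torch.Tensor:
    # all_gather returns a list with world_size entries; concatenate them
    # into one (world_size * batch, ...) tensor matching the local layout.
    return torch.cat(torch.distributed.nn.all_gather(x), dim=0)

# Hypothetical usage inside the ITC loss:
# image_itc_all = gather_with_grad(image_itc)
# text_itc_all  = gather_with_grad(text_itc)
```

If the lists returned by `all_gather` were passed directly to the similarity computation, the matrix multiplications would fail on a list rather than hang, so the freeze itself more likely points to a collective that some rank never reaches.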