[New Feature] Loading HuggingFace .safetensors and .bin variants for CLIP models
Issue
I tried to load HuggingFace .safetensors/.bin checkpoints into CLIP, but it failed with this error:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File /opt/conda/lib/python3.11/site-packages/clip/clip.py:129, in load(name, device, jit, download_root)
    127 try:
    128     # loading JIT archive
--> 129     model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
    130     state_dict = None

File /opt/conda/lib/python3.11/site-packages/torch/jit/_serialization.py:165, in load(f, map_location, _extra_files, _restore_shapes)
    164 else:
--> 165     cpp_module = torch._C.import_ir_module_from_buffer(
    166         cu, f.read(), map_location, _extra_files, _restore_shapes
    167     )  # type: ignore[call-arg]
    169 # TODO: Pretty sure this approach loses ConstSequential status and such

RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory
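For context, a minimal sketch of the failure mode (the checkpoint filename below is a placeholder): clip.load first attempts torch.jit.load, which only understands TorchScript zip archives, so a plain HuggingFace pickle or safetensors file fails at the zip central-directory check.

# Minimal repro sketch; "pytorch_model.bin" is a placeholder for any
# downloaded HuggingFace CLIP checkpoint file.
import torch

# HF ships plain pickled state_dicts / safetensors, not TorchScript zip
# archives, so this raises:
#   RuntimeError: PytorchStreamReader failed reading zip archive: ...
torch.jit.load("pytorch_model.bin", map_location="cpu")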
Fix
- Load .bin and .safetensors files in a dedicated way rather than through torch.jit.load:
import torch
from safetensors import safe_open

if model_path.endswith('.bin'):
    # HF .bin checkpoints are plain pickled state_dicts
    state_dict = torch.load(model_path, map_location="cpu")
elif model_path.endswith('.safetensors'):
    # safetensors files are read tensor by tensor
    with safe_open(model_path, framework="pt", device="cpu") as f:
        state_dict = {key: f.get_tensor(key) for key in f.keys()}
- But this alone is insufficient, because HuggingFace and OpenAI use different naming conventions and storage formats for the state_dict. Additional logic converts the HuggingFace state_dict to the OpenAI format (see the sketch below).
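The PR's actual converter is not reproduced here; as an illustration of the kind of remapping involved, the sketch below renames a few representative text-tower keys and concatenates HuggingFace's separate q/k/v projections into the single in_proj tensors that OpenAI's nn.MultiheadAttention expects. The helper name hf_to_openai_state_dict and the partial rename table are illustrative assumptions, not the PR's code; the same pattern extends to the vision tower (vision_model.* -> visual.*) and the MLP/LayerNorm blocks.

import re
import torch

def hf_to_openai_state_dict(hf_sd):
    # Illustrative sketch: direct renames for a few representative
    # text-tower tensors; a real converter must cover every parameter.
    renames = {
        "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
        "text_model.embeddings.position_embedding.weight": "positional_embedding",
        "text_model.final_layer_norm.weight": "ln_final.weight",
        "text_model.final_layer_norm.bias": "ln_final.bias",
        "logit_scale": "logit_scale",
    }
    out = {new: hf_sd[old] for old, new in renames.items() if old in hf_sd}
    # HF stores projections as nn.Linear weights of shape [out, in]; OpenAI
    # stores plain matrices applied on the right ([in, out]), hence the .t().
    if "text_projection.weight" in hf_sd:
        out["text_projection"] = hf_sd["text_projection.weight"].t()
    # HF keeps separate q/k/v projections; OpenAI's nn.MultiheadAttention
    # wants them concatenated into single in_proj_weight / in_proj_bias.
    q_pat = re.compile(r"text_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.weight")
    for key in list(hf_sd):
        m = q_pat.match(key)
        if m is None:
            continue
        i = m.group(1)
        prefix = f"text_model.encoder.layers.{i}.self_attn"
        out[f"transformer.resblocks.{i}.attn.in_proj_weight"] = torch.cat(
            [hf_sd[f"{prefix}.{p}_proj.weight"] for p in ("q", "k", "v")])
        out[f"transformer.resblocks.{i}.attn.in_proj_bias"] = torch.cat(
            [hf_sd[f"{prefix}.{p}_proj.bias"] for p in ("q", "k", "v")])
    # Partial: vision tower, attn.out_proj, MLP, and per-layer LayerNorms omitted.
    return out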
Tests
I tested the enhancements on
- https://huggingface.co/openai/clip-vit-base-patch16
- https://huggingface.co/openai/clip-vit-base-patch32
- https://huggingface.co/openai/clip-vit-large-patch14-336
- https://huggingface.co/openai/clip-vit-large-patch14
using the following snippet, which checks that the state_dicts of the original model and the HuggingFace model are exactly the same. It passed for all four models.
import torch
import clip
from tqdm.auto import tqdm

# Reference: the official OpenAI checkpoint, loaded the usual way.
official_model, _ = clip.load("ViT-B/16", device="cpu")
official_state_dict = official_model.state_dict()

model_dir = "<your local directory>"

try:
    # HuggingFace .bin variant
    model, _ = clip.load(f"{model_dir}/pytorch_model.bin", device="cpu")
    state_dict = model.state_dict()
    assert len(state_dict) == len(official_state_dict)
    for key in tqdm(official_state_dict, total=len(official_state_dict)):
        assert key in state_dict
        assert torch.equal(official_state_dict[key], state_dict[key])

    # HuggingFace .safetensors variant
    model, _ = clip.load(f"{model_dir}/model.safetensors", device="cpu")
    state_dict = model.state_dict()
    assert len(state_dict) == len(official_state_dict)
    for key in tqdm(official_state_dict, total=len(official_state_dict)):
        assert key in state_dict
        assert torch.equal(official_state_dict[key], state_dict[key])
except Exception as e:
    print(e)
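The snippet above compares against "ViT-B/16"; to test the other three checkpoints, swap in the matching official model name ("ViT-B/32", "ViT-L/14", or "ViT-L/14@336px") and point model_dir at the corresponding download.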