MllamaForCausalLM not returning past_key_values even with use_cache=True
System Info
- `transformers` version: 4.45.2
- Platform: Linux-5.4.0-187-generic-x86_64-with-glibc2.31
- Python version: 3.11.5
- Huggingface_hub version: 0.23.4
- Safetensors version: 0.4.3
- Accelerate version: 0.33.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.2 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA RTX 6000 Ada Generation
Who can help?
@amyeroberts @ArthurZucker
Information
- [ ] The official example scripts
- [x] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [x] My own task or dataset (give details below)
Reproduction
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
import torch

# Local cache directory for downloaded weights; adjust as needed.
cache_directory = "./hf_cache"

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir=cache_directory,
    torch_dtype=torch.float16,
    device_map='auto',
)
processor = AutoProcessor.from_pretrained(model_id, cache_dir=cache_directory)

input_ids = tokenizer.encode('Hi, tell me a story of frog.', add_special_tokens=False, return_tensors='pt').to(model.device)
with torch.no_grad():
    # Forward pass with use_cache=True; past_key_values in the output comes back as None.
    output = model(input_ids=input_ids, use_cache=True)
output
Expected behavior
I expect `past_key_values` to be present in the output. However, it is `None`.
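For reference, the check I'm running on the returned forward output (just attribute access on the output object):

```python
# With use_cache=True I expect a populated cache here, but it is None.
print(output.keys())
print(output.past_key_values)  # -> None
```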
Hmm, actually we implemented Mllama quite similarly to Idefics, so the cache is not initialized by default when `use_cache=True`. And yes, I think it makes sense to init an empty cache for models that don't need a special cache class the way Gemma does.
Until the fix lands you can get past-kv by passing `model(**inputs, past_key_values=DynamicCache(), use_cache=True)`, but I see that the model weights will not be loaded properly for `MllamaForCausalLM`. In fact `MllamaForConditionalGeneration` can handle text-only input, so for correct logits computation I'd recommend using `MllamaForConditionalGeneration` :)
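A minimal sketch of that workaround, assuming the same checkpoint as above and a text-only prompt (tokenizing with `AutoTokenizer` here instead of the processor is my own shortcut, not part of the suggestion):

```python
import torch
from transformers import AutoTokenizer, DynamicCache, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load the conditional-generation model so all checkpoint weights are used;
# it also accepts text-only input.
model = MllamaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

inputs = tokenizer("Hi, tell me a story of frog.", return_tensors="pt").to(model.device)
with torch.no_grad():
    # Pass an explicit DynamicCache until the default-cache fix lands,
    # so the forward call returns a populated past_key_values.
    output = model(**inputs, past_key_values=DynamicCache(), use_cache=True)

print(output.past_key_values)  # a DynamicCache instead of None
```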
With `use_cache=True` we should probably just init a default cache for the user, or we opt for forcing users to pass a cache object.
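For the first option, a rough sketch of what initializing a default cache for the user could look like inside the model's forward; the helper name below is illustrative, not the actual Mllama code:

```python
from typing import Optional

from transformers.cache_utils import Cache, DynamicCache


def _maybe_init_cache(past_key_values: Optional[Cache], use_cache: bool) -> Optional[Cache]:
    # If the caller asked for caching but supplied no cache object,
    # fall back to an empty DynamicCache so past_key_values is
    # populated during the forward pass instead of staying None.
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()
    return past_key_values
```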
Yes, exactly. I can make a PR for that
Not stale, PR in progress.