Enable caching in the "generate" and "stream_generate" functions so the KV cache persists across multiple requests
- Add two new data classes, CacheHistory and StepOutput, for storing the cache history along with the token history (see the first sketch below).
- Add a return_cache option to "generate" and "stream_generate" so the resulting cache can be reused by later requests.
- Add two functions to save and load the cache from disk.
- The "prompt" argument to "generate" and "stream_generate" is no longer a suffix appended to the cache history; it is now the full prompt. In "generate_step", a check finds the length of the longest shared prefix between the new prompt's token ids and the cached prompt's token ids, so only the tokens past that prefix need processing (see the second sketch below).
Usage
from mlx_lm import load, stream_generate
from mlx_lm.utils import save_cache, load_cache
model, tokenizer = load('/Path/to/model')
prompt = 'Your long prompt here...'
# First generation without prompt cache history
for text, cache in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=100,
    return_cache=True,
    verbose=True,
):
    print(text, end='')
# Processing prompt (1431/1431): 100%|██████████| 3/3 [00:02<00:00, 1.50it/s]
# Prompt preprocessing time for 1431 tokens: 2.007s (713.1801 tok/sec)
# Second generation with prompt cache history
for text, cache in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=100,
    return_cache=True,
    verbose=True,
    cache_history=cache,
):
    print(text, end='')
# Processing prompt (1/1): 100%|██████████| 1/1 [00:00<00:00, 595.61it/s]
# Prompt preprocessing time for 1 tokens: 0.001921s (520.6299 tok/sec)
# Save the cache history to use later
save_cache(cache, filename='test.safetensors', metadata=dict(model_id='My random model'))
# Load an existing cache from disk
cache, metadata = load_cache(filename='test.safetensors')
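For reference, save_cache and load_cache can be built on MLX's safetensors I/O. The following is a rough sketch assuming each layer cache exposes keys/values arrays; it is not the PR's exact implementation:

import mlx.core as mx

def save_cache(cache_history, filename, metadata=None):
    # Flatten per-layer keys/values into one dict of arrays
    arrays = {'tokens': mx.array(cache_history.tokens)}
    for i, layer in enumerate(cache_history.cache):
        arrays[f'{i}_keys'] = layer.keys
        arrays[f'{i}_values'] = layer.values
    # safetensors metadata must be a dict of strings
    mx.save_safetensors(filename, arrays, metadata=metadata)

def load_cache(filename):
    arrays, metadata = mx.load(filename, return_metadata=True)
    # Rebuilding per-layer cache objects from `arrays` is model
    # specific and omitted here.
    return arrays, metadata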
Just updating the title of the PR for clarity. With these changes, the KV cache from any generation can now be reused for other requests.
The code in server.py is updated to match the changes to generate_step. Prompt caching is enabled in server.py by default.
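For example, the server can thread the returned cache back into the next call; this is a hypothetical handler sketch, not the actual server.py code:

from mlx_lm import stream_generate

# Keep the last cache around so repeated prompts that share a
# prefix skip re-processing that prefix.
cache_history = None

def handle_completion(model, tokenizer, prompt, max_tokens=256):
    global cache_history
    for text, cache_history in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=max_tokens,
        return_cache=True,
        cache_history=cache_history,
    ):
        yield text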
Thanks for the PR! However, most of this functionality should already be included in https://github.com/ml-explore/mlx-examples/pull/1015 and https://github.com/ml-explore/mlx-examples/pull/1026, so I will close this.
If there is anything here that those don't address please feel free to submit a follow up PR rebased on the latest. Thanks!