
Enable caching for 'generate' and 'stream_generate' functions to ensure persistence of cache across multiple requests

nath1295 opened this issue 1 year ago • 2 comments

  1. Add two new data classes, CacheHistory and StepOutput, for storing the cache history along with the token history.
  2. Add a return_cache option to "generate" and "stream_generate" so the cache can be reused by later requests.
  3. Add two functions to save the cache to disk and load it back.
  4. The "prompt" argument in "generate" and "stream_generate" is no longer treated as a suffix of the cache history; it is now the full prompt. In "generate_step", a check finds the index of the longest shared prefix between the token IDs of the new prompt and the token IDs of the cached prompt, and only the tokens after that prefix are processed.

Usage

from mlx_lm import load, stream_generate
from mlx_lm.utils import save_cache, load_cache

model, tokenizer = load('/Path/to/model')

prompt = 'Your long prompt here...'

# First generation without prompt cache history
for i, cache in stream_generate(model=model, 
        tokenizer=tokenizer, prompt=prompt, max_tokens=100, return_cache=True, verbose=True):
    print(i, end='')
# Processing prompt (1431/1431): 100%|██████████| 3/3 [00:02<00:00,  1.50it/s]
# Prompt preprocessing time for 1431 tokens: 2.007s (713.1801 tok/sec)

# Second generation, reusing the prompt cache from the first run
for i, cache in stream_generate(model=model, 
        tokenizer=tokenizer, prompt=prompt, max_tokens=100, return_cache=True, verbose=True, cache_history=cache):
    print(i, end='')
# Processing prompt (1/1): 100%|██████████| 1/1 [00:00<00:00, 595.61it/s]
# Prompt preprocessing time for 1 tokens: 0.001921s (520.6299 tok/sec)

# Save the cache history to use later
save_cache(cache, filename='test.safetensors', metadata=dict(model_id='My random model'))

# Load an existing cache from disk
cache, metadata = load_cache(filename='test.safetensors')

nath1295 avatar Sep 17 '24 20:09 nath1295

Just updating the title of the PR for clarity. With these changes, the KV cache from any generation can now be reused for other requests.

nath1295 avatar Sep 18 '24 10:09 nath1295

The code in server.py has been updated to match the changes to generate_step. Prompt caching is now enabled in server.py by default.

nath1295 avatar Sep 27 '24 16:09 nath1295

Thanks for the PR! However, most of this functionality should already be included in https://github.com/ml-explore/mlx-examples/pull/1015 and https://github.com/ml-explore/mlx-examples/pull/1026, so I will close this.

If there is anything here that those don't address please feel free to submit a follow up PR rebased on the latest. Thanks!

awni avatar Oct 12 '24 21:10 awni