Feature: Prompt cache
Currently MLXLMCommon has some basic support for a cache; however, it isn't persisted across calls to generate().
Even though it appears there could be a way to pass a KVCache to generate(), it would ultimately have to pass through the Sendable boundary if the app is to manage the cache. That isn't possible, as MLXArray is not Sendable, and it also isn't desirable or necessary.
A prompt cache could be managed by the ModelContainer actor and stored in its context ModelContext.promptCache. Note that the prompt cache is an array of KVCache. In mlx_lm the PromptCache object also stores the token ids of the cached prompt and the model key to check if the model has changed.
We could implement a similar struct:
```swift
public struct PromptCache {
    public let cache: [KVCache]
    public let modelKey: String
    public let tokens: MLXArray
}
```
The PromptCache struct could also have functions for trimming.
Functions analogous to mlx_lm's get_prompt_cache could go in the ModelContainer actor.
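To make the trimming idea concrete, here is a sketch of the prefix comparison such a struct would need. The `PromptCacheState` name, the `commonPrefixLength` helper, and the use of plain `[Int]` token arrays are all assumptions for illustration; the real struct stores an MLXArray:

```swift
// Sketch only: a common-prefix helper that trimming could build on.
// Token ids are modeled as [Int] here for simplicity.
public struct PromptCacheState {
    public var tokens: [Int]

    /// Number of leading tokens shared between the cached prompt and a
    /// new prompt. Everything after this prefix must be re-evaluated
    /// (after trimming the KV cache back to this length).
    public func commonPrefixLength(with prompt: [Int]) -> Int {
        var i = 0
        while i < tokens.count && i < prompt.count && tokens[i] == prompt[i] {
            i += 1
        }
        return i
    }
}
```

If the new prompt extends the cached one, the common prefix is the whole cached prompt and only the suffix needs evaluating; otherwise the cache must be trimmed (or discarded, if the cache type doesn't support trimming).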
I'm currently having a go at implementing this. Interested in any suggestions on the best approach.
I don't think it belongs in ModelContainer -- that is scoped to the lifetime of the model (weights really). The KVCache is more like a session. Think of it in the case of a server that handles multiple chat sessions.
I maybe cheated a bit here: https://github.com/ml-explore/mlx-swift-examples/pull/277/files#diff-8821b90aba52905c4daae1bd76fb6ca81f8dcca661c3f3d33a25362eaa82f818R66
I made it inside the ModelContainer and the async state keeps it live.
I think it needs to live outside the ModelContainer to model the lifetime correctly. The ChatExample app would be a good use case -- it would be @State for the view ultimately. Can we use sending? Do we need another actor to contain it (and use nested visit calls)?
It isn't entirely obvious to me, so any investigation you do will be valuable!
To be honest I hadn't thought much about a longer lifecycle of the cache. I saw in cache.py there is a save_prompt_cache() which would allow it to persist across application launches.
If the cache is to persist outside of the ModelContainer.perform() Sendable action, it needs to be converted to something other than an MLXArray and copied across. Additionally, should the cache also then be sent back across the boundary and converted back to an MLXArray? It would seem to me that the latter is unnecessary unless the application wants to do some special cache manipulation, but that could still be handled as a Sendable action on ModelContainer.
Additionally, if an app isn't concerned with having the cache persist across sessions, there could be an unnecessary performance impact from sending the cache across the boundary once or twice per step. So syncing it across the boundary should be an option rather than the default, which means it needs to be stored with the ModelContainer.
The question is what is the performance impact of frequently syncing the cache? We would have to do some testing to find out.
If we take the approach of storing the cache in ModelContainer, and the app wants to persist it across sessions, it could be left to the app to explicitly grab the cache, e.g. ModelContainer.cacheToArray(). And it can do this as frequently as it likes.
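`cacheToArray()` is only a hypothetical name, but the core of the idea would be converting each MLXArray in the cache state into a Sendable value type. A minimal sketch, assuming the cache's key/value tensors are reachable as MLXArrays (the `ArraySnapshot` name is invented for illustration):

```swift
import MLX

// Sketch: convert an MLXArray into a plain Sendable snapshot that can
// safely cross an isolation boundary (or be written to disk). This is
// the copy cost the surrounding discussion is worried about.
public struct ArraySnapshot: Sendable {
    public let shape: [Int]
    public let values: [Float]
}

func snapshot(_ array: MLXArray) -> ArraySnapshot {
    ArraySnapshot(shape: array.shape, values: array.asArray(Float.self))
}
```

Because every element is copied out of the MLXArray, doing this once per explicit "save" call is cheap enough, but doing it every generation step is exactly the overhead that argues for keeping sync optional.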
I'm not much of an expert on Sendable so I'm not sure what tricks could be used to have the MLXArray stored outside of the ModelContainer. I think you are implying that it is possible without MLXArray being Sendable itself?
Yes, that is my hope. There are some annotations like sending and consuming that can move non-Sendable data between isolation contexts. And actor can provide an isolation context that can itself be moved around. I don't think we want to serialize the KVCache into something that is Sendable.
At a minimum we can do something like this (not tested but I think the idea is sound):
```swift
import Foundation

final class KVCacheContainer: @unchecked Sendable {
    private let lock = NSLock()
    private let kvCache: [KVCache]

    init(kvCache: [KVCache]) {
        self.kvCache = kvCache
    }

    func withLock<R>(_ visitor: ([KVCache]) -> R) -> R {
        lock.withLock {
            visitor(kvCache)
        }
    }
}
```
That guarantees exclusive access to the KVCache and the container itself is Sendable.
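Used from another isolation context it might look like the following (illustrative only, and assuming KVCache exposes its current offset, as the concrete cache implementations do):

```swift
// Illustrative use: the container itself is Sendable, so any task can
// capture it, while the lock serializes access to the non-Sendable
// KVCache values inside.
func reportCacheLength(container: KVCacheContainer) {
    let offset = container.withLock { kvCache in
        kvCache.first?.offset ?? 0
    }
    print("cached tokens:", offset)
}
```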
> I don't think we want to serialize the KVCache into something that is Sendable.
I agree, let me try a few of those suggestions...
I think I am getting my head around this.
If we use an actor, the only way the [KVCache] array can be mutated is via an inout parameter, and the closure must be synchronous (similar to ModelContainer.update()). We would only need this when creating a new cache (LanguageModel.newCache()), and that should be a relatively fast operation.
Any other mutations or function calls on an individual KVCache can be performed in an asynchronous closure.
It would need both synchronous and asynchronous closure functions:
```swift
actor KVCacheContainer {
    var kvCache: [KVCache] = []

    public func perform<R>(_ action: ([KVCache]) async throws -> R) async rethrows -> R {
        try await action(kvCache)
    }

    public func update<R>(_ action: (inout [KVCache]) throws -> R) rethrows -> R {
        try action(&kvCache)
    }
}
```
It would be used something like:
```swift
// Create cache -- update must be inside as it is synchronous
await modelContainer.perform { context in
    await cacheContainer.update { kvCache in
        kvCache[0] = KVCacheSimple()
    }
}

// Update cache -- either order is fine
await cacheContainer.perform { kvCache in
    await modelContainer.perform { context in
        kvCache[0].increment(n: 5)
    }
}
```
I've got an implementation that is working and I've integrated it into the MLXChatExample. The changes were cleaner than I was expecting but there are a few caveats.
https://github.com/jolonf/mlx-swift-examples/tree/feature/prompt-caching
I've decided to call the actor PromptCache which is equivalent to KVCacheContainer above.
The issue of needing to mutate the [KVCache] array turned out to be a non-issue as the only time it needs to be changed is when it is recreated and as a result the entire PromptCache can be recreated with the new [KVCache].
```swift
public actor PromptCache {
    public let cache: [KVCache]
    public var tokens: MLXArray

    public init(cache: [KVCache]) {
        print("[PromptCache.init]")
        self.cache = cache
        self.tokens = []
    }

    // ...
}
```
Currently functions for getting the uncached prompt suffix and common prefix are in the PromptCache actor.
```swift
public func getUncachedSuffix(prompt: MLXArray) async -> MLXArray? { ... }
```
There are no changes to ModelContainer or ModelContext.
Evaluate.swift generate() functions have a cache parameter added to pass in the cache.
```swift
public func generate(
    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]?
) throws -> AsyncStream<Generation> { ... }
```
KVCache and KVCacheSimple have isTrimmable() and trim() functions added.
These are the only changes to the library.
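The trim semantics mirror mlx_lm's prompt-cache trimming: a simple cache "trims" by moving its offset back, so entries past the offset are overwritten on the next update. A sketch of what the additions might look like (the protocol name is invented here, and the actual signatures in the branch may differ):

```swift
// Hypothetical sketch of the isTrimmable()/trim() additions.
public protocol TrimmableKVCache {
    /// True if this cache type supports dropping tokens from the end.
    func isTrimmable() -> Bool

    /// Drop up to `count` tokens from the end of the cache and return
    /// the number actually trimmed (bounded by the current length).
    func trim(_ count: Int) -> Int
}
```

Cache types that can't cheaply discard their tail (e.g. rotating or quantized variants) would return false from isTrimmable(), in which case the caller falls back to rebuilding the cache from scratch.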
The main change to the MLXChatExample app is in MLXService:
```swift
// Generate response using the model
return try await modelContainer.perform { (context: ModelContext) in
    let fullPrompt = try await context.processor.prepare(input: userInput)
    let parameters = GenerateParameters(temperature: 0.7)

    // Get the prompt cache
    let cache: PromptCache
    if let existingCache = self.promptCache[model.name] {
        cache = existingCache
    } else {
        // Create cache if it doesn't exist yet
        cache = PromptCache(cache: context.model.newCache(parameters: parameters))
        self.promptCache[model.name] = cache
    }

    let lmInput: LMInput

    // Remove prefix from prompt that is already in cache
    if let suffix = await cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
        lmInput = LMInput(text: LMInput.Text(tokens: suffix))
    } else {
        // If suffix is nil, the cache is inconsistent with the new prompt
        // and the cache doesn't support trimming, so create a new one here.
        self.promptCache[model.name] = PromptCache(cache: context.model.newCache(parameters: parameters))
        lmInput = fullPrompt
    }

    // TODO: cache.perform ...
    // TODO: The generated tokens should be added to the prompt cache but not possible with AsyncStream
    return try MLXLMCommon.generate(
        input: lmInput, parameters: parameters, context: context, cache: await cache.cache)
}
```
Currently I have just used await cache.cache to pass the cache to generate(). It works, but could be isolated better.
Also, it is currently up to the app to add newly generated tokens to PromptCache.tokens, but the AsyncStream doesn't return the generated tokens, so the previously generated response is not cached.
Interestingly with my testing with Qwen3 0.6B the last 5 tokens were always trimmed from the last prompt. This meant that the LLM's response would always have been trimmed anyway.
Is this along the lines of what you were thinking?
Just to clarify, the only required change to the project to support caching is adding the [KVCache] parameter to the generate() functions and adding isTrimmable() and trim() to KVCache/KVCacheSimple.
Technically the PromptCache doesn't need to be in the library and could be provided by the app.
The changes to Evaluate.swift and KVCache.swift are small and benign, they could be incorporated now. Apps could provide their own PromptCache until we include one. Or we could include the current one in the MLXChatExample for now.
So... I just realised that the projects I was using to test this were set to Swift 5 instead of Swift 6. So you can pretty much ignore what I said above 😞
As mentioned above, the most critical thing is adding the cache parameters to the generate() functions. I've created a pull request https://github.com/ml-explore/mlx-swift-examples/pull/312 that does this and adds the trim functions to KVCache.
It also has an example of implementing a prompt cache in MLXChatExample. The PromptCache has been moved into the app rather than the library and is an @unchecked Sendable class.
I haven't implemented any isolation on the cache at this stage.