Is it possible to support concurrency in model container

Open Remember2015 opened this issue 9 months ago • 10 comments

https://github.com/ml-explore/mlx-swift-examples/blob/42b0d21a607b97c3b99d1e20626b7d9b14827b81/Libraries/MLXLMCommon/ModelContainer.swift#L9

As the comment above mentions, access is currently single-threaded only, but in some cases concurrency is really useful for speeding things up.

Thanks for the reply.

Remember2015 avatar Jul 16 '25 06:07 Remember2015

No, but you could potentially use a ModelContext instead -- this is the payload that the ModelContainer uses. You would have to ensure that there are no concurrency issues manually.
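
For example, something along these lines (a rough sketch only; whether the factory exposes a load(configuration:) that returns the raw ModelContext, and its exact parameters, is an assumption here):

import MLXLLM
import MLXLMCommon

// Sketch: hold the ModelContext directly instead of going through the
// ModelContainer actor. NOTE: load(configuration:) returning a ModelContext
// is an assumption about the factory API; adjust to the real signature.
let context = try await LLMModelFactory.shared.load(
    configuration: ModelConfiguration(id: "mlx-community/Qwen3-4B-4bit"))

// From here on you are responsible for serializing access yourself, e.g. by
// confining `context` to a single queue or actor of your own so that two
// generations never touch the same ModelContext at the same time.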

davidkoski avatar Jul 16 '25 22:07 davidkoski

No, but you could potentially use a ModelContext instead -- this is the payload that the ModelContainer uses. You would have to ensure that there are no concurrency issues manually.

Thank you, I'll try it.

Remember2015 avatar Jul 18 '25 07:07 Remember2015

Hello, after reading the code further, I found that the framework does not support batch generation, which means multiple ModelContexts are needed in parallel. Am I right? And it seems that the memory occupied by the model weights will be multiplied? And the KV cache may also be a problem?

Remember2015 avatar Sep 16 '25 09:09 Remember2015

Hello, after reading the code further, I found that the framework does not support batch generation, which means multiple ModelContexts are needed in parallel. Am I right? And it seems that the memory occupied by the model weights will be multiplied? And the KV cache may also be a problem?

I am not sure what you mean by "batch generation". The prefill of the context uses the prefillStepSize as a batch:

  • https://github.com/ml-explore/mlx-swift-examples/blob/main/Libraries/MLXLLM/LLMModel.swift#L30

You can't do inference in batch exactly -- it prepares one token at a time and the next token relies on the state of the previous. The TokenIterator evaluates after every step, but if you were to write this loop yourself you could easily do inference N tokens at a time.
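
If that means only forcing evaluation every N steps, here is a rough illustration (the token sequence below stands in for whatever loop you write yourself; it is not the exact MLXLMCommon API):

import MLX

// Sketch: collect tokens from any sequence of MLXArray tokens, but only force
// evaluation every `stride` steps instead of after every token. Because MLX is
// lazy, evaluating the most recent token realizes the pending graph behind it.
func collectTokens(_ tokens: some Sequence<MLXArray>, evalEvery stride: Int = 8) -> [MLXArray] {
    var collected = [MLXArray]()
    for token in tokens {
        collected.append(token)
        if collected.count % stride == 0 {
            eval(token)
        }
    }
    if let last = collected.last {
        eval(last)  // flush whatever is still pending at the end
    }
    return collected
}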

But maybe this isn't what you mean by batch generation?

As to the other question, no, you don't have to load the weights twice. For example:

let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit")
let session1 = ChatSession(model)
let session2 = ChatSession(model)

session1 and session2 will share the model weights but will have their own independent KVCaches.
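
Concretely (assuming ChatSession exposes an async respond(to:) that returns the reply as a String):

// Both sessions reuse the same loaded weights; each keeps its own KV cache,
// so the two conversations do not interfere with each other.
let answer1 = try await session1.respond(to: "Summarize lazy evaluation in MLX.")
let answer2 = try await session2.respond(to: "Write a haiku about Swift actors.")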

davidkoski avatar Sep 16 '25 15:09 davidkoski

@davidkoski, thank you for your detailed reply. "Batch generation" may be the wrong term for what I mean: doing prompt generation in parallel, e.g. tasks A and B running simultaneously, where A has its own prompt with model X and B has a different prompt with the same model.

Remember2015 avatar Sep 17 '25 02:09 Remember2015

But what if I want session1 and session2 working simultaneously, just like a server responding to 2 requests at the same time?

Remember2015 avatar Sep 17 '25 06:09 Remember2015

But what if I want session1 and session2 working simultaneously, just like a server responding to 2 requests at the same time?

That should work, though it does serialize on the GPU.

Specifically, unevaluated MLXArrays are not thread safe. The weights are evaluated when the model is loaded, so those are OK. The KVCache will be owned by the session, so it won't be shared across threads.

The one remaining piece is the random state (used in sampling) -- by default it is global and is an unevaluated MLXArray (so concurrent use will crash). There is a way around this though; look at this test:

  • https://github.com/ml-explore/mlx-swift/blob/main/Tests/MLXTests/MLXRandomTests.swift#L237

  • https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/withrandomstate(_:body:)-18ob4

You can have task-local random state. This could be managed in the code that calls respond/streamResponse.
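
A rough sketch of what that could look like (the RandomState type, the seeds, and an async overload of withRandomState are assumptions here, based on the test and docs linked above):

import MLX
import MLXLMCommon
import MLXRandom

let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit")

// Two concurrent generations over the same weights: each task gets its own
// ChatSession (its own KV cache) and its own task-local random state, so
// sampling never touches the shared global random state.
let taskA = Task {
    let session = ChatSession(model)
    return try await withRandomState(MLXRandom.RandomState(seed: 0)) {
        try await session.respond(to: "Prompt A")
    }
}
let taskB = Task {
    let session = ChatSession(model)
    return try await withRandomState(MLXRandom.RandomState(seed: 1)) {
        try await session.respond(to: "Prompt B")
    }
}
let (a, b) = try await (taskA.value, taskB.value)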

I think that should work!

davidkoski avatar Sep 17 '25 15:09 davidkoski

https://github.com/user-attachments/assets/192a964a-58f6-4514-9903-58e181bac14c

I tried this: 3 different prompts running with Qwen 1.7B, on 3 different tasks.

rudrankriyam avatar Sep 17 '25 16:09 rudrankriyam

Thanks for your patient reply, @davidkoski. I'll read the code further to understand the sampling implementation.

Remember2015 avatar Sep 18 '25 02:09 Remember2015

FPzw5_SCxvnKjByV.mp4 I tried this: 3 different prompts running with Qwen 1.7B, on 3 different tasks.

That’s awesome, let me try it

Remember2015 avatar Sep 18 '25 02:09 Remember2015