
Context parallelism with MLA

Open · SuperCB opened this issue 11 months ago • 4 comments

I have a question regarding FusedAttention: why doesn't it support context parallelism with MLA (Multi-head Layer Attention)? What are the technical limitations preventing this compatibility?

SuperCB, Mar 08 '25

Hi @SuperCB

You mean Multi-head Latent Attention, which is used by DeepSeek? Technically, nothing should stop us from doing it; we just have not done it yet. Considering the popularity of MLA/DeepSeek, we should definitely add this support. We will do it. Thanks for bringing this to our attention.

xrennvidia, Mar 10 '25

I am working on it too. It looks to me like the function AttnFuncWithCPAndQKVOA2A can already support context parallelism for MLA. Is my conclusion correct, and what are the main reasons currently preventing MLA from supporting context parallelism?

SuperCB, Mar 11 '25

Yeah, the A2A implementation can probably work with MLA out of the box. AttnFuncWithCPAndKVAllGather might also work for MLA.

The P2P implementation cannot work because it concatenates K and V into a single tensor for communication; the different head_dim of K and V prevents that concatenation. This should be addressable, though.

As I said, technically there should be no reason preventing MLA+CP. At least I do not know of any reason right now; I might find something once I start working on this.

xrennvidia, Mar 11 '25
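For reference, here is a minimal sketch (not TransformerEngine code) of the concat issue described above. The head dims are hypothetical MLA-style values (DeepSeek-V2, for example, uses a larger QK head dim than V head dim); the A2A and all-gather paths presumably avoid this because they communicate the tensors separately.

```python
# Minimal sketch (not TransformerEngine code) of the K/V packing problem.
# Head dims below are hypothetical MLA-style values.
import torch

seq_len, batch, n_heads = 1024, 2, 16
qk_head_dim, v_head_dim = 192, 128  # K and V have different head dims in MLA

k = torch.randn(seq_len, batch, n_heads, qk_head_dim)
v = torch.randn(seq_len, batch, n_heads, v_head_dim)

# With standard MHA (equal head dims), K and V can be packed into one tensor
# and exchanged with a single P2P send/recv per ring step. With MLA the
# mismatched last dims make that packing fail:
try:
    kv = torch.stack([k, v], dim=0)
except RuntimeError as err:
    print("cannot pack K and V into one tensor:", err)
```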

I think we can support MLA+CP in the P2P implementation by padding V so its head_dim matches K's, which keeps modifications to the original code minimal. I am currently trying this approach.


SuperCB, Mar 11 '25
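A rough sketch of that padding idea (my reading of it, with hypothetical shapes, not the actual patch): zero-pad V up to K's head_dim so the existing pack-and-send path stays unchanged, then slice the padding off after communication.

```python
# Sketch of the padding workaround (hypothetical shapes, not the actual
# TransformerEngine change).
import torch
import torch.nn.functional as F

seq_len, batch, n_heads = 1024, 2, 16
qk_head_dim, v_head_dim = 192, 128

k = torch.randn(seq_len, batch, n_heads, qk_head_dim)
v = torch.randn(seq_len, batch, n_heads, v_head_dim)

# Zero-pad the last dim of V up to K's head_dim so both can be packed
# into a single tensor for the P2P exchange.
v_padded = F.pad(v, (0, qk_head_dim - v_head_dim))
kv = torch.stack([k, v_padded], dim=0)  # [2, s, b, h, qk_head_dim]

# ... kv would be sent/received between CP ranks here ...

# After communication, unpack and drop the padding before attention.
k_recv = kv[0]
v_recv = kv[1][..., :v_head_dim]
assert torch.equal(v_recv, v)
```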

Closed by #1729.

yaox12, Jun 26 '25