[JAX] cuBlasMp integration for CollectiveGemm custom op
Description
This PR integrates TE/common cuBlasMp bindings into the TE/JAX CollectiveGemm custom op.
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [ ] Code refactoring
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
Greptile Overview
Greptile Summary
Integrates a cuBlasMp backend into the JAX CollectiveGemm operations alongside the existing userbuffers backend, controlled by the NVTE_WITH_CUBLASMP environment variable. The implementation introduces a CollectiveGemmPlan abstraction that supports both backends through runtime dispatch.
Critical compilation errors found:
- Missing semicolon after the `CollectiveGemmPlan` struct definition (cgemm_helper.h:179)
- Undefined function calls `get_cublasmp_context()` and `get_userbuffers_context()` in bootstrap code (cgemm_helper.cpp:141,144)
- Incorrect object instantiation of `CommOverlapP2PBase`: creates a stack object instead of a heap pointer (cgemm_helper.cpp:205-214)
- Undefined variable `userbuffers_ctx` used instead of `ctx` (gemm.cpp:304)
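The semicolon and object-lifetime findings can be illustrated with a small sketch. `OverlapExecutor`, `Plan` and `make_plan` are hypothetical stand-ins (not TE types): `OverlapExecutor` plays the role of `CommOverlapP2PBase`, whose real constructor arguments are omitted here. The point is only that an executor cached in a plan must outlive the function that creates it, so it is heap-allocated and owned by the plan.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for a userbuffers executor such as CommOverlapP2PBase;
// the real constructor arguments are omitted in this sketch.
struct OverlapExecutor {
  explicit OverlapExecutor(int n) : num_ranks(n) {}
  int num_ranks;
};

// Hypothetical plan type. The semicolon after the closing brace is required;
// omitting it is a compile error.
struct Plan {
  std::unique_ptr<OverlapExecutor> executor;  // the plan owns its executor
};

// Anti-pattern, shown only as a comment: `exec` is destroyed when the
// function returns, so a pointer to it kept in a cached plan would dangle.
//   OverlapExecutor exec(num_ranks);
//   cached_ptr = &exec;

// Heap-allocate so the executor lives as long as the plan that owns it.
inline Plan make_plan(int num_ranks) {
  Plan plan;
  plan.executor = std::make_unique<OverlapExecutor>(num_ranks);
  assert(plan.executor->num_ranks == num_ranks);
  return plan;
}
```

Using `std::unique_ptr` rather than a raw `new` also ties the executor's destruction to the plan's, so a registry that drops a plan does not leak the backend object.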
Design considerations:
- The runtime environment-variable check in the hot path may impact performance; consider caching the decision
- The dual-backend approach adds complexity; ensure both paths are tested thoroughly
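One way to cache the backend decision is a function-local static, which C++11 guarantees is initialized exactly once even with concurrent callers. `parse_bool_env` and `use_cublasmp_backend` are hypothetical helper names, and treating `"1"`/`"true"` as enabled is an assumption for this sketch, not necessarily how TE parses `NVTE_WITH_CUBLASMP`.

```cpp
#include <cassert>
#include <cstdlib>
#include <string>

// Parses a boolean-like environment variable. Treating "1" and "true"
// (in three casings) as enabled is an assumption for this sketch.
inline bool parse_bool_env(const char *name) {
  assert(name != nullptr);
  const char *value = std::getenv(name);
  if (value == nullptr) return false;
  const std::string s(value);
  return s == "1" || s == "true" || s == "True" || s == "TRUE";
}

// Hypothetical accessor: the function-local static is initialized exactly
// once (thread-safe since C++11), so hot-path calls only read a cached bool.
inline bool use_cublasmp_backend() {
  static const bool cached = parse_bool_env("NVTE_WITH_CUBLASMP");
  return cached;
}
```

The trade-off is that changing the variable after the first call has no effect, which is usually what a per-process backend choice wants.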
Confidence Score: 0/5
- This PR cannot compile due to multiple syntax errors
- The score reflects critical compilation failures (a missing semicolon, undefined functions, and incorrect object lifetime management) that will prevent a successful build
- All C++ files (cgemm_helper.h, cgemm_helper.cpp, gemm.cpp) require fixes before this can compile
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| build_tools/jax.py | 5/5 | adds NVTE_WITH_CUBLASMP compiler flag when environment variable is set - straightforward build configuration change |
| transformer_engine/jax/csrc/extensions/cgemm_helper.h | 1/5 | introduces CollectiveGemmPlan struct to abstract cuBlasMp and userbuffers backends - has missing semicolon causing compilation failure |
| transformer_engine/jax/csrc/extensions/cgemm_helper.cpp | 1/5 | refactors get_executor to get_plan with dual backend support - has undefined function calls and incorrect object instantiation causing compilation failures |
| transformer_engine/jax/csrc/extensions/gemm.cpp | 2/5 | integrates cuBlasMp API calls for AG+GEMM and GEMM+RS operations - has undefined variable reference causing compilation failure |
Sequence Diagram
```mermaid
sequenceDiagram
    participant JAX as JAX Frontend
    participant Init as CollectiveGemmInitFFI
    participant Registry as CollectiveGemmPlanRegistry
    participant Plan as CollectiveGemmPlan
    participant CuBlasMp as cuBlasMp Backend
    participant UB as Userbuffers Backend
    participant Gemm as GemmFFI

    Note over JAX,Gemm: Initialization Phase
    JAX->>Init: Initialize collective GEMM
    Init->>Registry: get_plan(buffer_shape, dtype, collective_op)
    alt NVTE_WITH_CUBLASMP=true
        Registry->>CuBlasMp: nvte_comm_gemm_ctx_create(comm, nranks, rank)
        CuBlasMp-->>Registry: NVTECommGemmCtx*
        Registry->>Plan: Create plan with cublasmp_context
    else NVTE_WITH_CUBLASMP=false
        Registry->>UB: new CommOverlapP2PBase(...)
        UB-->>Registry: CommOverlapCore*
        Registry->>Plan: Create plan with userbuffers_context
    end
    Plan-->>Registry: CollectiveGemmPlan*
    Registry-->>Init: Plan cached and returned

    Note over JAX,Gemm: Execution Phase
    JAX->>Gemm: Execute GEMM operation
    Gemm->>Registry: get_plan(buffer_shape, dtype, collective_op)
    Registry-->>Gemm: Retrieve cached plan
    alt collective_op=REDUCE_SCATTER
        alt use_cublasmp=true
            Gemm->>CuBlasMp: nvte_gemm_reduce_scatter(ctx, m, n, k, ...)
            CuBlasMp-->>Gemm: Complete GEMM+RS
        else use_cublasmp=false
            Gemm->>UB: ctx->split_overlap_rs(...)
            UB-->>Gemm: Complete GEMM+RS
        end
    else collective_op=ALL_GATHER
        alt use_cublasmp=true
            Gemm->>CuBlasMp: nvte_all_gather_gemm(ctx, m, n, k, ...)
            CuBlasMp-->>Gemm: Complete AG+GEMM
        else use_cublasmp=false
            Gemm->>UB: ctx->split_overlap_ag(...)
            UB-->>Gemm: Complete AG+GEMM
        end
    end
    Gemm-->>JAX: Return result
```
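The plan lookup and two-level dispatch in the diagram can be sketched as a cache keyed on `(buffer_shape, dtype, collective_op)` whose plans carry a backend tag. Every type here is a hypothetical stand-in that only mirrors names from the diagram: the real registry also creates the cuBlasMp context or the `CommOverlapP2PBase` executor, and `dispatch_target` merely returns the name of the call each branch would make instead of issuing it.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical types mirroring names in the diagram; not TE's definitions.
enum class CollectiveOp { ALL_GATHER, REDUCE_SCATTER };
enum class Backend { CUBLASMP, USERBUFFERS };

struct CollectiveGemmPlan {
  Backend backend;
  CollectiveOp op;
};

class CollectiveGemmPlanRegistry {
 public:
  explicit CollectiveGemmPlanRegistry(bool use_cublasmp) : use_cublasmp_(use_cublasmp) {}

  // The first call for a key builds the plan (initialization phase);
  // later calls return the cached plan (execution phase).
  CollectiveGemmPlan *get_plan(const std::vector<std::size_t> &buffer_shape, int dtype,
                               CollectiveOp op) {
    Key key{buffer_shape, dtype, op};
    auto it = plans_.find(key);
    if (it != plans_.end()) return it->second.get();
    auto plan = std::make_unique<CollectiveGemmPlan>(
        CollectiveGemmPlan{use_cublasmp_ ? Backend::CUBLASMP : Backend::USERBUFFERS, op});
    CollectiveGemmPlan *raw = plan.get();
    plans_.emplace(std::move(key), std::move(plan));
    return raw;
  }

 private:
  using Key = std::tuple<std::vector<std::size_t>, int, CollectiveOp>;
  bool use_cublasmp_;
  std::map<Key, std::unique_ptr<CollectiveGemmPlan>> plans_;
};

// Names the call each execution-phase branch of the diagram would make.
inline std::string dispatch_target(const CollectiveGemmPlan &plan) {
  const bool cublasmp = plan.backend == Backend::CUBLASMP;
  if (plan.op == CollectiveOp::REDUCE_SCATTER)
    return cublasmp ? "nvte_gemm_reduce_scatter" : "split_overlap_rs";
  return cublasmp ? "nvte_all_gather_gemm" : "split_overlap_ag";
}
```

Storing the backend in the plan means the hot path branches on a cached field rather than re-reading the environment on every GEMM call.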
Hi, could you add some unit tests?
Hi @denera - This needs to go into TE2.10. I see @phu0ngng's last comment asking for some tests. Please let us know whether that has already been done and whether you think it will take much longer. @KshitijLakhani will keep an eye on this and cherry-pick it into TE2.10 when it is ready.