TransformerEngine

[JAX] cuBlasMp integration for CollectiveGemm custom op

Open • denera opened this issue 3 months ago • 2 comments

Description

This PR integrates TE/common cuBlasMp bindings into the TE/JAX CollectiveGemm custom op.

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or a new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [ ] Code refactoring

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [ ] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] New and existing unit tests pass locally with my changes

denera, Nov 07 '25 13:11

Greptile Overview

Greptile Summary

Integrates a cuBlasMp backend into the JAX CollectiveGemm operations alongside the existing userbuffers backend, selected by the NVTE_WITH_CUBLASMP environment variable. The implementation introduces a CollectiveGemmPlan abstraction that supports both backends through runtime dispatch.

Critical compilation errors found:

  • Missing semicolon in CollectiveGemmPlan struct definition (cgemm_helper.h:179)
  • Undefined function calls get_cublasmp_context() and get_userbuffers_context() in bootstrap code (cgemm_helper.cpp:141,144)
  • Incorrect object instantiation of CommOverlapP2PBase - creates stack object instead of heap pointer (cgemm_helper.cpp:205-214)
  • Undefined variable userbuffers_ctx used instead of ctx (gemm.cpp:304)
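The object-lifetime and missing-semicolon findings above can be illustrated with a minimal sketch. The class bodies here are hypothetical stand-ins, not the real TE definitions; only the names CommOverlapCore, CommOverlapP2PBase, and CollectiveGemmPlan come from this PR, and the constructor argument is an assumption made for illustration.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-ins for the real TE classes; only the ownership pattern matters.
struct CommOverlapCore {
  virtual ~CommOverlapCore() = default;
  virtual std::string name() const { return "core"; }
};

struct CommOverlapP2PBase : CommOverlapCore {
  explicit CommOverlapP2PBase(int tp_size) : tp_size_(tp_size) {}
  std::string name() const override { return "p2p"; }
  int tp_size_;
};

// The plan owns its backend through a base-class pointer.
struct CollectiveGemmPlan {
  std::unique_ptr<CommOverlapCore> userbuffers_ctx;
};  // a struct definition must end with a semicolon

// Problematic pattern, kept as a comment:
//   CommOverlapP2PBase ctx(tp_size);  // stack object
//   plan.userbuffers_ctx = ctx;       // does not compile; &ctx would dangle after return
//
// Heap allocation keeps the object alive for as long as the plan exists.
CollectiveGemmPlan make_plan(int tp_size) {
  CollectiveGemmPlan plan;
  plan.userbuffers_ctx = std::make_unique<CommOverlapP2PBase>(tp_size);
  return plan;
}
```

Whether the real plan holds a raw pointer (as `new CommOverlapP2PBase(...)` would suggest) or a smart pointer is not stated here; `std::unique_ptr` is used in the sketch only because it makes the ownership explicit.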

Design considerations:

  • Runtime environment variable check in hot path may impact performance - consider caching the decision
  • The dual backend approach adds complexity - ensure both paths are tested thoroughly
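One way to cache the backend decision, assuming it really is read from the environment at runtime, is a function-local static. The helper names `parse_flag` and `use_cublasmp` and the accepted spellings ("1", "true") are assumptions for this sketch, not taken from the PR.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>

// Interpret an environment-variable value as a boolean flag.
// Which spellings the real code accepts is an assumption here.
inline bool parse_flag(const char* value) {
  if (value == nullptr) return false;
  return std::strcmp(value, "1") == 0 || std::strcmp(value, "true") == 0;
}

// Read NVTE_WITH_CUBLASMP once; later calls reuse the cached result.
// Function-local statics are initialized exactly once, thread-safely, since C++11.
inline bool use_cublasmp() {
  static const bool cached = parse_flag(std::getenv("NVTE_WITH_CUBLASMP"));
  return cached;
}
```

With this shape the hot path pays only for a guarded read of a static boolean. The trade-off is that changing the variable after the first call has no effect for the lifetime of the process.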

Confidence Score: 0/5

  • This PR does not compile as submitted; the errors are not only syntactic
  • Score reflects critical compilation failures: a missing semicolon, undefined functions, an undefined variable, and incorrect object instantiation that will prevent a successful build
  • All C++ files (cgemm_helper.h, cgemm_helper.cpp, gemm.cpp) require fixes before this can compile

Important Files Changed

File Analysis

| Filename | Score | Overview |
|---|---|---|
| build_tools/jax.py | 5/5 | adds NVTE_WITH_CUBLASMP compiler flag when environment variable is set - straightforward build configuration change |
| transformer_engine/jax/csrc/extensions/cgemm_helper.h | 1/5 | introduces CollectiveGemmPlan struct to abstract cuBlasMp and userbuffers backends - has missing semicolon causing compilation failure |
| transformer_engine/jax/csrc/extensions/cgemm_helper.cpp | 1/5 | refactors get_executor to get_plan with dual backend support - has undefined function calls and incorrect object instantiation causing compilation failures |
| transformer_engine/jax/csrc/extensions/gemm.cpp | 2/5 | integrates cuBlasMp API calls for AG+GEMM and GEMM+RS operations - has undefined variable reference causing compilation failure |

Sequence Diagram

sequenceDiagram
    participant JAX as JAX Frontend
    participant Init as CollectiveGemmInitFFI
    participant Registry as CollectiveGemmPlanRegistry
    participant Plan as CollectiveGemmPlan
    participant CuBlasMp as cuBlasMp Backend
    participant UB as Userbuffers Backend
    participant Gemm as GemmFFI

    Note over JAX,Gemm: Initialization Phase
    JAX->>Init: Initialize collective GEMM
    Init->>Registry: get_plan(buffer_shape, dtype, collective_op)
    
    alt NVTE_WITH_CUBLASMP=true
        Registry->>CuBlasMp: nvte_comm_gemm_ctx_create(comm, nranks, rank)
        CuBlasMp-->>Registry: NVTECommGemmCtx*
        Registry->>Plan: Create plan with cublasmp_context
    else NVTE_WITH_CUBLASMP=false
        Registry->>UB: new CommOverlapP2PBase(...)
        UB-->>Registry: CommOverlapCore*
        Registry->>Plan: Create plan with userbuffers_context
    end
    
    Plan-->>Registry: CollectiveGemmPlan*
    Registry-->>Init: Plan cached and returned

    Note over JAX,Gemm: Execution Phase
    JAX->>Gemm: Execute GEMM operation
    Gemm->>Registry: get_plan(buffer_shape, dtype, collective_op)
    Registry-->>Gemm: Retrieve cached plan
    
    alt collective_op=REDUCE_SCATTER
        alt use_cublasmp=true
            Gemm->>CuBlasMp: nvte_gemm_reduce_scatter(ctx, m, n, k, ...)
            CuBlasMp-->>Gemm: Complete GEMM+RS
        else use_cublasmp=false
            Gemm->>UB: ctx->split_overlap_rs(...)
            UB-->>Gemm: Complete GEMM+RS
        end
    else collective_op=ALL_GATHER
        alt use_cublasmp=true
            Gemm->>CuBlasMp: nvte_all_gather_gemm(ctx, m, n, k, ...)
            CuBlasMp-->>Gemm: Complete AG+GEMM
        else use_cublasmp=false
            Gemm->>UB: ctx->split_overlap_ag(...)
            UB-->>Gemm: Complete AG+GEMM
        end
    end
    
    Gemm-->>JAX: Return result
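The flow in the diagram (create a plan once, cache it, then dispatch on backend and collective op) can be sketched as below. This is a simplified model, not the PR's code: the key type, the `Backend` enum, the string dtype, and the `dispatch` helper are assumptions, and `dispatch` only returns the name of the call the diagram shows instead of invoking it.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

enum class CollectiveOp { ALL_GATHER, REDUCE_SCATTER };
enum class Backend { CUBLASMP, USERBUFFERS };

// Hypothetical plan: records which backend was selected at creation time.
struct CollectiveGemmPlan {
  Backend backend;
};

class CollectiveGemmPlanRegistry {
 public:
  explicit CollectiveGemmPlanRegistry(bool use_cublasmp) : use_cublasmp_(use_cublasmp) {}

  // Returns the cached plan for this key, creating it on first use.
  CollectiveGemmPlan* get_plan(const std::vector<std::size_t>& buffer_shape,
                               const std::string& dtype, CollectiveOp op) {
    Key key{buffer_shape, dtype, op};
    auto it = plans_.find(key);
    if (it == plans_.end()) {
      auto plan = std::make_unique<CollectiveGemmPlan>();
      plan->backend = use_cublasmp_ ? Backend::CUBLASMP : Backend::USERBUFFERS;
      it = plans_.emplace(std::move(key), std::move(plan)).first;
    }
    return it->second.get();
  }

  std::size_t size() const { return plans_.size(); }

 private:
  using Key = std::tuple<std::vector<std::size_t>, std::string, CollectiveOp>;
  bool use_cublasmp_;
  std::map<Key, std::unique_ptr<CollectiveGemmPlan>> plans_;
};

// Execution-phase dispatch; the returned strings name the calls shown in the diagram.
inline std::string dispatch(const CollectiveGemmPlan& plan, CollectiveOp op) {
  const bool cublasmp = plan.backend == Backend::CUBLASMP;
  if (op == CollectiveOp::REDUCE_SCATTER) {
    return cublasmp ? "nvte_gemm_reduce_scatter" : "split_overlap_rs";
  }
  return cublasmp ? "nvte_all_gather_gemm" : "split_overlap_ag";
}
```

Storing the backend choice in the plan at creation time means the execution phase never has to consult the environment again, which also addresses the hot-path concern raised in the design considerations.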

greptile-apps[bot], Nov 14 '25 19:11

Hi, could you add some unit tests?

phu0ngng, Nov 14 '25 21:11

Hi @denera - This needs to go into TE2.10. I see @phu0ngng's last comment asking for some tests. Please let us know whether that has already been done and, if not, whether you think it will take much longer. @KshitijLakhani will keep an eye on this to cherry-pick it into TE2.10 when ready.

nvMelissa, Nov 19 '25 21:11