CUDA.jl

Easier way to do mixed-mode matrix multiplication

Open • zygi opened this issue 2 years ago • 6 comments

Describe the bug

In deep learning, people often use fp16 matmuls with fp32 accumulation (the cuBLAS compute type) as a balance between performance and numerical accuracy. In Torch, fp32 compute is the default for an fp16-by-fp16 matmul. In CUDA.jl the default is fp16 accumulation, and there doesn't seem to be an easy way to get fp32-accumulation behavior.

It would be great if there was either a toggle to change this behavior, similar to math_mode, or maybe even to make the fp32-accum behavior the default.

Specifically, fp16 gemm! is currently dispatched to cublasHgemm, whereas the suggested behavior (and the way Torch does it) is to dispatch to cublasSgemm with the input/output datatype arguments set to fp16.

This also applies to batched matmuls, where CUDA.jl dispatches to cublasHgemmBatched, and possibly to batched matrix-vector products.

I'm happy to open a PR if the maintainers decide it's ok to change the current behavior without introducing a setting. If a setting is needed it might be better for someone more familiar with the project's structure to do this.

To reproduce

Use Nsight Compute to see that the kernel used is ampere_h1688gemm_128x128_ldg8_stages_32x1_nn, or some other kernel with h1688 in its name.

Version info


julia> versioninfo()
Julia Version 1.9.2
Commit e4ee485e909 (2023-07-05 09:39 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900K
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, goldmont)
  Threads: 1 on 32 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =

julia> CUDA.versioninfo()
CUDA runtime 12.1, artifact installation
CUDA driver 12.2
NVIDIA driver 535.86.5

CUDA libraries:

  • CUBLAS: 12.1.3
  • CURAND: 10.3.2
  • CUFFT: 11.0.2
  • CUSOLVER: 11.4.5
  • CUSPARSE: 12.1.0
  • CUPTI: 18.0.0
  • NVML: 12.0.0+535.86.5

Julia packages:

  • CUDA: 4.4.0
  • CUDA_Driver_jll: 0.5.0+1
  • CUDA_Runtime_jll: 0.6.0+0

Toolchain:

  • Julia: 1.9.2
  • LLVM: 14.0.6
  • PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0, 7.1, 7.2, 7.3, 7.4, 7.5
  • Device capability support: sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80, sm_86

1 device: 0: NVIDIA GeForce RTX 4090 (sm_89, 21.899 GiB / 23.988 GiB available)

zygi • Aug 05 '23 19:08

Oops, I was misreading the code and focused on the wrong path. Since we're using gemmEx, the relevant CUDA.jl behavior is defined in the gemmEx compute-type selection. In that case we probably need much less plumbing. Can we just make it

    if sig === (Float16, Float16)
        # NOTE: Float16=Float16*Float16 can also happen in 32-bit compute
        if reduced_precision === :TensorFloat32
            return math_mode==CUDA.PEDANTIC_MATH ? CUDA.CUBLAS.CUBLAS_COMPUTE_32F_PEDANTIC : CUDA.CUBLAS.CUBLAS_COMPUTE_32F
        else
            return math_mode==CUDA.PEDANTIC_MATH ? CUDA.CUBLAS.CUBLAS_COMPUTE_16F_PEDANTIC : CUDA.CUBLAS.CUBLAS_COMPUTE_16F
        end
    end

?

zygi • Aug 07 '23 18:08

> it doesn't seem to be possible to easily get fp32-accum behavior.

The API is as follows:

julia> using CUDA

julia> A = CUDA.rand(Float16, 2, 2)
2×2 CuArray{Float16, 2, CUDA.Mem.DeviceBuffer}:
 0.4697  0.956
 0.718   0.79

julia> B = CUDA.rand(Float16, 2, 2)
2×2 CuArray{Float16, 2, CUDA.Mem.DeviceBuffer}:
 0.8115  0.6846
 0.963   0.2109

julia> C = CUDA.zeros(Float32, 2, 2)
2×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 0.0  0.0
 0.0  0.0

julia> using LinearAlgebra

julia> mul!(C, A, B)
2×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 1.30177  0.523229
 1.34321  0.658015

> It would be great if there was either a toggle to change this behavior, similar to math_mode, or maybe even to make the fp32-accum behavior the default.

Changing the default behavior of *(::CuArray{Float16}, ::CuArray{Float16}) is not easily possible. It would make the function type unstable, and be highly surprising to other Julia developers expecting * to behave as it does everywhere else.

As such, I don't think this is "just" a CUDA.jl issue, and if you want a different behavior it's probably better to discuss this in a place where people familiar with Julia's array interfaces can chime in. A Discourse post, maybe?

> Can we just make it
>
>     if sig === (Float16, Float16)
>         # NOTE: Float16=Float16*Float16 can also happen in 32-bit compute
>         if reduced_precision === :TensorFloat32
>             return math_mode==CUDA.PEDANTIC_MATH ? CUDA.CUBLAS.CUBLAS_COMPUTE_32F_PEDANTIC : CUDA.CUBLAS.CUBLAS_COMPUTE_32F
>         else
>             return math_mode==CUDA.PEDANTIC_MATH ? CUDA.CUBLAS.CUBLAS_COMPUTE_16F_PEDANTIC : CUDA.CUBLAS.CUBLAS_COMPUTE_16F
>         end
>     end

No, that just changes the computational domain. The output container has already been determined and allocated at that point, because that happens in LinearAlgebra.jl and not in CUDA.jl (see above; this is why it probably warrants a wider discussion).
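A CPU-only sketch of the point about the output container, assuming only Base Julia and the LinearAlgebra standard library (no GPU, no CUDA.jl): for `A * B`, LinearAlgebra chooses and allocates the destination from the input element types before any backend routine runs, so a Float16 product yields a Float16 array. Only `mul!` with a caller-provided destination can produce a Float32 result, mirroring the `CuArray` example above.

```julia
using LinearAlgebra

A = rand(Float16, 4, 4)
B = rand(Float16, 4, 4)

# `*` allocates its own destination. LinearAlgebra derives the element type from
# the input element types, so a Float16 * Float16 product gives a Float16 array.
C = A * B
println(eltype(C))    # Float16

# `mul!` writes into a destination the caller has already allocated,
# so the caller can request a Float32 result from Float16 inputs.
C32 = zeros(Float32, 4, 4)
mul!(C32, A, B)
println(eltype(C32))  # Float32
```

Changing the compute type inside a backend's gemm wrapper affects how the sum is accumulated, but not the element type `*` has already picked for its result.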

maleadt • Aug 08 '23 09:08

Thanks for the reply!

> No, that just changes the computational domain

sorry for being unclear, this is exactly what I had in mind. The desired behavior is to read in two Float16 matrices and output a Float16 matrix, but do the internal computation as Float32.

(fwiw I tested that change and confirmed that it then dispatches to the same kernel as default-Torch)

The motivation for why this is desirable:

  • Matmuls involve a large number of additions (each output element C[k,m] is dot(A[k,:], B[:,m])), so accumulating in low precision leads to significant accuracy loss
  • Empirically, people found that pure FP16 matmuls in neural networks (specifically transformers) don't work well
  • But people also realized that for several past generations of CUDA accelerators (and very likely future ones), GPU matmul is memory-bandwidth-bound, not compute-bound
  • So what if, in the hot loop, each thread loads fp16 values (cheap), casts them to fp32, accumulates in fp32 (more expensive, but compute is abundant), and, once it's done with its assigned blocks, casts the result back to fp16 and writes it out to memory?
  • It turns out that this matmul is still significantly faster than a plain fp32×fp32 matmul, while preserving enough accuracy to be useful in practice
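The accumulation argument above can be illustrated on the CPU without cuBLAS. The sketch below uses a hypothetical helper, `dot_acc` (not part of CUDA.jl or LinearAlgebra), that computes one output element of a matmul from Float16 inputs, widens each input to the chosen accumulator type, accumulates there, and rounds the final sum back to Float16. It models only the numerics, not the speed of a GPU kernel.

```julia
# Hypothetical helper: dot product of Float16 vectors, accumulating in `Tacc`
# and rounding the final sum back to Float16, as the stored output would be.
function dot_acc(::Type{Tacc}, x::AbstractVector{Float16},
                 y::AbstractVector{Float16}) where {Tacc<:AbstractFloat}
    length(x) == length(y) || throw(DimensionMismatch("vectors must have equal length"))
    acc = zero(Tacc)
    for i in eachindex(x, y)
        # widen the Float16 inputs, then multiply and add in the accumulator type
        acc += Tacc(x[i]) * Tacc(y[i])
    end
    return Float16(acc)  # narrow back to Float16 for storage
end

x = ones(Float16, 4096)
y = ones(Float16, 4096)

dot_acc(Float16, x, y)  # Float16 accumulator: the sum stalls at 2048
dot_acc(Float32, x, y)  # Float32 accumulator: 4096
```

With a Float16 accumulator the running sum gets stuck at 2048: above that value adjacent Float16 numbers are 2 apart, so adding 1 rounds back down (round half to even). The Float32 accumulator reaches 4096, which is exactly representable when narrowed back to Float16.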

zygi • Aug 08 '23 15:08

> No, that just changes the computational domain

> sorry for being unclear, this is exactly what I had in mind. The desired behavior is to read in two Float16 matrices and output a Float16 matrix, but do the internal computation as Float32.

Ah OK, I was confused by the mention of cublasSgemm, where AFAIK you can't do this (the input/output types are Float32).

I'd be OK with adding this to the set of math modes, or even with defaulting to 32-bit compute for Float16 matmul. Maybe some other ML people should chime in here; cc @ToucheSir @DhairyaLGandhi. I do feel like we should have a better way to control this, though. The pedantic/default/fast math modes don't really fit this, and neither does math_precision == :TensorFloat32. Does Torch allow configuring this?

maleadt • Aug 18 '23 09:08

> Ah OK, I was confused by the mention of cublasSgemm, where AFAIK you can't do this (the input/output types are Float32).

Sorry, yes, I meant cublasSgemmEx there

> Does Torch allow configuring this?

They do: https://github.com/pytorch/pytorch/blob/68cb854d73458a14684d584c25c22b17eb79dfca/aten/src/ATen/cuda/CUDABlas.cpp#L506

The PR where the setting was introduced: https://github.com/pytorch/pytorch/commit/790763b0feffbfa5dd9fb4ed6c6a0ac35ef35fa2

zygi • Aug 18 '23 20:08

I had a read through that code and the docs at https://pytorch.org/docs/stable/notes/cuda.html#fp16reducedprecision as well. It doesn't look like PyTorch lets you configure the computation type away from fp32? In that sense there doesn't seem to be a switch between what cublasHgemm and cublasSgemmEx do. Is the former really so bad that it's not even worth supporting in an ML context?

ToucheSir • Aug 18 '23 21:08