
[CUDNN] Support BFloat16

Open · AntonOresten opened this pull request 2 months ago · 9 comments

This PR defines the methods needed to make cuDNN work with BFloat16s.BFloat16.

In the following example, I show how the new methods fix the BFloat16 backward pass of Flux.logitcrossentropy:

Before

Note: Core.BFloat16 === BFloat16s.BFloat16, but I didn't explicitly import it in this REPL session.

julia> x, y = CUDA.randn(Core.BFloat16, 32), CUDA.randn(Core.BFloat16, 32); Flux.gradient(x) do x
           Flux.logitcrossentropy(x, y)
       end
ERROR: MethodError: no method matching cudnnDataType(::Type{Core.BFloat16})
The function `cudnnDataType` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  cudnnDataType(::Type{Float16})
   @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:7
  cudnnDataType(::Type{Float32})
   @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:8
  cudnnDataType(::Type{Float64})
   @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:9
  ...

Stacktrace:
  [1] cuDNN.cudnnTensorDescriptor(array::CuArray{Core.BFloat16, 4, CUDA.DeviceMemory}; format::cuDNN.cudnnTensorFormat_t, dims::Vector{Int32})
    @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/tensor.jl:9
  [2] cudnnSoftmaxForward!(y::CuArray{…}, x::CuArray{…}; o::@Kwargs{…})
    @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/softmax.jl:17
  [3] logsoftmax!(y::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}, x::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}; dims::Int64)
    @ NNlibCUDACUDNNExt ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:90
  [4] logsoftmax!
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:87 [inlined]
  [5] #logsoftmax#41
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:20 [inlined]
  [6] logsoftmax
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:19 [inlined]
  [7] #rrule#109
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:129 [inlined]
  [8] rrule
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:128 [inlined]
  [9] rrule
    @ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:144 [inlined]
 [10] chain_rrule_kw
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:246 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:-1 [inlined]
 [12] _pullback
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81 [inlined]
 [13] #logitcrossentropy#20
    @ ~/.julia/packages/Flux/uRn8o/src/losses/functions.jl:272 [inlined]
 [14] _pullback(::Zygote.Context{…}, ::Flux.Losses.var"##logitcrossentropy#20", ::Int64, ::typeof(mean), ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{…}, ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [15] _pullback(::Zygote.Context{…}, ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{…}, ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
 [16] #8
    @ ./REPL[14]:2 [inlined]
 [17] _pullback(ctx::Zygote.Context{false}, f::var"#8#9", args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [18] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:96
 [19] pullback
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:94 [inlined]
 [20] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:153
 [21] #gradient#1
    @ ~/.julia/packages/Flux/uRn8o/src/gradient.jl:44 [inlined]
 [22] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Flux ~/.julia/packages/Flux/uRn8o/src/gradient.jl:31
 [23] top-level scope
    @ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.
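
The first failure is simply the missing cudnnDataType mapping. As a rough sketch of the kind of definition involved (assuming the cuDNN wrapper exposes the CUDNN_DATA_BFLOAT16 enum value, present in cuDNN 8.1 and later; this mirrors the existing Float16/Float32 methods in util.jl rather than reproducing the PR's exact code):

using BFloat16s: BFloat16

# Map the Julia element type to cuDNN's data-type enum, following the pattern of
# the existing cudnnDataType(::Type{Float16}) = CUDNN_DATA_HALF method in util.jl.
cuDNN.cudnnDataType(::Type{BFloat16}) = cuDNN.CUDNN_DATA_BFLOAT16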

After defining cudnnDataType(::Type{BFloat16})

julia> x, y = CUDA.randn(Core.BFloat16, 32), CUDA.randn(Core.BFloat16, 32); Flux.gradient(x) do x
           Flux.logitcrossentropy(x, y)
       end
ERROR: Unknown tensor type Core.BFloat16
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:44
  [2] scalingParameter(T::Type, val::Int64)
    @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:34
  [3] cudnnSoftmaxForwardWithDefaults(x::CuArray{…}; y::CuArray{…}, algo::cuDNN.cudnnSoftmaxAlgorithm_t, mode::cuDNN.cudnnSoftmaxMode_t, alpha::Int64, beta::Int64, format::cuDNN.cudnnTensorFormat_t, xDesc::cuDNN.cudnnTensorDescriptor, yDesc::cuDNN.cudnnTensorDescriptor)
    @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/softmax.jl:34
  [4] cudnnSoftmaxForward!(y::CuArray{…}, x::CuArray{…}; o::@Kwargs{…})
    @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/softmax.jl:17
  [5] logsoftmax!(y::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}, x::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}; dims::Int64)
    @ NNlibCUDACUDNNExt ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:90
  [6] logsoftmax!
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:87 [inlined]
  [7] #logsoftmax#41
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:20 [inlined]
  [8] logsoftmax
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:19 [inlined]
  [9] #rrule#109
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:129 [inlined]
 [10] rrule
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:128 [inlined]
 [11] rrule
    @ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:144 [inlined]
 [12] chain_rrule_kw
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:246 [inlined]
 [13] macro expansion
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:-1 [inlined]
 [14] _pullback
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81 [inlined]
 [15] #logitcrossentropy#20
    @ ~/.julia/packages/Flux/uRn8o/src/losses/functions.jl:272 [inlined]
 [16] _pullback(::Zygote.Context{…}, ::Flux.Losses.var"##logitcrossentropy#20", ::Int64, ::typeof(mean), ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{…}, ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [17] _pullback(::Zygote.Context{false}, ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}, ::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
 [18] #11
    @ ./REPL[19]:2 [inlined]
 [19] _pullback(ctx::Zygote.Context{false}, f::var"#11#12", args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [20] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:96
 [21] pullback
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:94 [inlined]
 [22] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:153
 [23] #gradient#1
    @ ~/.julia/packages/Flux/uRn8o/src/gradient.jl:44 [inlined]
 [24] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Flux ~/.julia/packages/Flux/uRn8o/src/gradient.jl:31
 [25] top-level scope
    @ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.
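
The remaining failure is in scalingParameter, which builds the host-side alpha/beta scaling factors passed to cuDNN kernels. cuDNN expects single-precision scaling values for half-precision tensor types, so a hedged sketch of the missing method, modelled on the existing Float16 one in util.jl, might look like:

# BFloat16 tensors, like Float16 ones, take their alpha/beta scaling
# factors as host-side single-precision values.
cuDNN.scalingParameter(::Type{BFloat16}, val) = Ref{Cfloat}(Cfloat(val))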

After defining scalingParameter(::Type{BFloat16}, val)

julia> x, y = CUDA.randn(Core.BFloat16, 32), CUDA.randn(Core.BFloat16, 32); Flux.gradient(x) do x
           Flux.logitcrossentropy(x, y)
       end
(Core.BFloat16[0.19335938, 0.32226562, -0.23828125, -0.85546875, 0.953125, 0.12207031, 1.15625, -0.64453125, -0.103515625, 0.61328125  …  0.4453125, -1.203125, 1.0234375, -1.46875, 0.19628906, -0.87890625, -1.3203125, 1.515625, 0.6484375, 0.44921875],)

I also define a cptr method for consistency, although the function doesn't appear to be used anywhere.

Tests are added for softmax, activations, and pooling. I initially also tested convolutions, normalization, RNNs, and multi-head attention, but those don't appear to support BFloat16.
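
For illustration, a test along these lines (a sketch, not the PR's exact test code) exercises the cuDNN softmax path and compares it against the generic CPU implementation:

using CUDA, cuDNN, NNlib, BFloat16s, Test

x = CUDA.randn(BFloat16, 8, 4)
# The GPU call dispatches through NNlib's cuDNN extension; the CPU call uses the generic fallback.
@test Array(softmax(x; dims=1)) ≈ softmax(Array(x); dims=1)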

Adding BFloat16s.jl as a dependency does not affect compilation since it's already a dependency of CUDA.jl.

Along with my proposed fix in https://github.com/FluxML/Optimisers.jl/issues/215, this has allowed me to train LLMs in BFloat16 with Flux.jl on Julia v1.12. I am still tinkering with Optimisers.jl, but together these changes would be a significant unlock for my lab.

AntonOresten · Nov 28 '25, 17:11