[CUDNN] Support BFloat16
This PR defines the methods needed to make cuDNN work with BFloat16s.BFloat16.
The following example shows how the new methods fix the BFloat16 backward pass of Flux.logitcrossentropy:
Before
Note: Core.BFloat16 === BFloat16s.BFloat16; the type prints as Core.BFloat16 below because I didn't explicitly import BFloat16s in this REPL session.
julia> x, y = CUDA.randn(Core.BFloat16, 32), CUDA.randn(Core.BFloat16, 32); Flux.gradient(x) do x
Flux.logitcrossentropy(x, y)
end
ERROR: MethodError: no method matching cudnnDataType(::Type{Core.BFloat16})
The function `cudnnDataType` exists, but no method is defined for this combination of argument types.
Closest candidates are:
cudnnDataType(::Type{Float16})
@ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:7
cudnnDataType(::Type{Float32})
@ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:8
cudnnDataType(::Type{Float64})
@ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:9
...
Stacktrace:
[1] cuDNN.cudnnTensorDescriptor(array::CuArray{Core.BFloat16, 4, CUDA.DeviceMemory}; format::cuDNN.cudnnTensorFormat_t, dims::Vector{Int32})
@ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/tensor.jl:9
[2] cudnnSoftmaxForward!(y::CuArray{…}, x::CuArray{…}; o::@Kwargs{…})
@ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/softmax.jl:17
[3] logsoftmax!(y::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}, x::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}; dims::Int64)
@ NNlibCUDACUDNNExt ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:90
[4] logsoftmax!
@ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:87 [inlined]
[5] #logsoftmax#41
@ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:20 [inlined]
[6] logsoftmax
@ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:19 [inlined]
[7] #rrule#109
@ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:129 [inlined]
[8] rrule
@ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:128 [inlined]
[9] rrule
@ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:144 [inlined]
[10] chain_rrule_kw
@ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:246 [inlined]
[11] macro expansion
@ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:-1 [inlined]
[12] _pullback
@ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81 [inlined]
[13] #logitcrossentropy#20
@ ~/.julia/packages/Flux/uRn8o/src/losses/functions.jl:272 [inlined]
[14] _pullback(::Zygote.Context{…}, ::Flux.Losses.var"##logitcrossentropy#20", ::Int64, ::typeof(mean), ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{…}, ::CuArray{…})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[15] _pullback(::Zygote.Context{…}, ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{…}, ::CuArray{…})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
[16] #8
@ ./REPL[14]:2 [inlined]
[17] _pullback(ctx::Zygote.Context{false}, f::var"#8#9", args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[18] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:96
[19] pullback
@ ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:94 [inlined]
[20] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:153
[21] #gradient#1
@ ~/.julia/packages/Flux/uRn8o/src/gradient.jl:44 [inlined]
[22] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
@ Flux ~/.julia/packages/Flux/uRn8o/src/gradient.jl:31
[23] top-level scope
@ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.
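The first failure is the missing cudnnDataType mapping. A minimal sketch of the new method, assuming the generated bindings expose the CUDNN_DATA_BFLOAT16 enum and following the style of the existing definitions in util.jl (not the exact diff):

using BFloat16s: BFloat16

# Map the Julia BFloat16 type to cuDNN's bfloat16 enum, mirroring the
# existing Float16/Float32/Float64 methods (sketch).
cudnnDataType(::Type{BFloat16}) = CUDNN_DATA_BFLOAT16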
After defining cudnnDataType(::Type{BFloat16})
julia> x, y = CUDA.randn(Core.BFloat16, 32), CUDA.randn(Core.BFloat16, 32); Flux.gradient(x) do x
Flux.logitcrossentropy(x, y)
end
ERROR: Unknown tensor type Core.BFloat16
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:44
[2] scalingParameter(T::Type, val::Int64)
@ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:34
[3] cudnnSoftmaxForwardWithDefaults(x::CuArray{…}; y::CuArray{…}, algo::cuDNN.cudnnSoftmaxAlgorithm_t, mode::cuDNN.cudnnSoftmaxMode_t, alpha::Int64, beta::Int64, format::cuDNN.cudnnTensorFormat_t, xDesc::cuDNN.cudnnTensorDescriptor, yDesc::cuDNN.cudnnTensorDescriptor)
@ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/softmax.jl:34
[4] cudnnSoftmaxForward!(y::CuArray{…}, x::CuArray{…}; o::@Kwargs{…})
@ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/softmax.jl:17
[5] logsoftmax!(y::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}, x::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}; dims::Int64)
@ NNlibCUDACUDNNExt ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:90
[6] logsoftmax!
@ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:87 [inlined]
[7] #logsoftmax#41
@ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:20 [inlined]
[8] logsoftmax
@ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:19 [inlined]
[9] #rrule#109
@ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:129 [inlined]
[10] rrule
@ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:128 [inlined]
[11] rrule
@ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:144 [inlined]
[12] chain_rrule_kw
@ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:246 [inlined]
[13] macro expansion
@ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:-1 [inlined]
[14] _pullback
@ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81 [inlined]
[15] #logitcrossentropy#20
@ ~/.julia/packages/Flux/uRn8o/src/losses/functions.jl:272 [inlined]
[16] _pullback(::Zygote.Context{…}, ::Flux.Losses.var"##logitcrossentropy#20", ::Int64, ::typeof(mean), ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{…}, ::CuArray{…})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[17] _pullback(::Zygote.Context{false}, ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}, ::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
[18] #11
@ ./REPL[19]:2 [inlined]
[19] _pullback(ctx::Zygote.Context{false}, f::var"#11#12", args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[20] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:96
[21] pullback
@ ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:94 [inlined]
[22] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:153
[23] #gradient#1
@ ~/.julia/packages/Flux/uRn8o/src/gradient.jl:44 [inlined]
[24] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
@ Flux ~/.julia/packages/Flux/uRn8o/src/gradient.jl:31
[25] top-level scope
@ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.
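The second failure is in scalingParameter, which builds the host-side alpha/beta factors passed to cuDNN. A sketch of the new method, assuming cuDNN's convention that non-double tensor types take single-precision scaling factors (the exact container should mirror the existing Float16 method in util.jl):

using BFloat16s: BFloat16

# cuDNN expects float (not bfloat16) alpha/beta for reduced-precision
# tensors, so BFloat16 gets a Float32 scaling parameter (sketch).
scalingParameter(::Type{BFloat16}, val) = Ref{Float32}(Float32(val))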
After defining scalingParameter(::Type{BFloat16}, val)
julia> x, y = CUDA.randn(Core.BFloat16, 32), CUDA.randn(Core.BFloat16, 32); Flux.gradient(x) do x
Flux.logitcrossentropy(x, y)
end
(Core.BFloat16[0.19335938, 0.32226562, -0.23828125, -0.85546875, 0.953125, 0.12207031, 1.15625, -0.64453125, -0.103515625, 0.61328125 … 0.4453125, -1.203125, 1.0234375, -1.46875, 0.19628906, -0.87890625, -1.3203125, 1.515625, 0.6484375, 0.44921875],)
I also define a cptr method for consistency, but it appears the function isn't used anywhere.
Tests are added for softmax, activations, and pooling. I initially also tested convolutions, normalization, RNNs, and MHA, but those operations don't appear to support BFloat16.
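To illustrate the kind of coverage added (a smoke test in the spirit of the new tests, not a copy of the test file), softmax through NNlib's cuDNN extension now works end to end:

using CUDA, cuDNN, BFloat16s, NNlib, Test

# logsoftmax on a CuArray{BFloat16} routes through NNlibCUDACUDNNExt,
# exactly the path that failed in the traces above.
x = CUDA.randn(BFloat16, 32)
y = NNlib.logsoftmax(x)
@test y isa CuArray{BFloat16}
# Probabilities recovered from log-softmax should sum to ~1; compare in
# Float32 to avoid BFloat16 accumulation error.
@test isapprox(sum(exp.(Float32.(Array(y)))), 1f0; atol=1f-2)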
Adding BFloat16s.jl as a direct dependency does not add any compilation cost, since it is already a dependency of CUDA.jl.
Along with my proposed fix in https://github.com/FluxML/Optimisers.jl/issues/215, this has allowed me to train LLMs in BFloat16 with Flux.jl in Julia v1.12. I am still tinkering with Optimisers.jl, but together these changes would be a significant unlock for my lab.