SparseDiffTools.jl

`autoback_hesvec` fails with `SparseMatrixCSC`

Open · newalexander opened this issue 3 years ago · 1 comment

Hello! I can't seem to use `autoback_hesvec` when the function includes a sparse CSC matrix (`SparseMatrixCSC`).

using Zygote, SparseArrays, SparseDiffTools

x, t = rand(Float32, 5), rand(Float32, 5)
A = sprand(Float32, 5, 5, 0.5)
loss(_x) = sum(tanh.(A * _x))

numback_hesvec(loss, x, t)  # works
autoback_hesvec(loss, x, t)  # fails

The `autoback_hesvec` call fails with the following error and stack trace:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/rounding.jl:200
  (::Type{T})(::T) where T<:Number at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/boot.jl:770
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/char.jl:50
  ...
Stacktrace:
  [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1}, i1::Int64)
    @ Base ./array.jl:903
  [3] (::ChainRulesCore.ProjectTo{SparseMatrixCSC, NamedTuple{(:element, :axes, :rowval, :nzranges, :colptr), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Vector{Int64}, Vector{UnitRange{Int64}}, Vector{Int64}}}})(dx::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/uxrij/src/projection.jl:580
  [4] #1334
    @ ~/.julia/packages/ChainRules/3HAQW/src/rulesets/Base/arraymath.jl:36 [inlined]
  [5] unthunk
    @ ~/.julia/packages/ChainRulesCore/uxrij/src/tangent_types/thunks.jl:197 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:104 [inlined]
  [7] map
    @ ./tuple.jl:223 [inlined]
  [8] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:105 [inlined]
  [9] ZBack
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:204 [inlined]
 [10] Pullback
    @ ./REPL[5]:1 [inlined]
 [11] (::typeof(∂(loss)))(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#57#58"{typeof(∂(loss))})(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:41
 [13] gradient(f::Function, args::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:76
 [14] (::SparseDiffTools.var"#78#79"{typeof(loss)})(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/9lSLn/src/differentiation/jaches_products_zygote.jl:39
 [15] autoback_hesvec(f::Function, x::Vector{Float32}, v::Vector{Float32})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/9lSLn/src/differentiation/jaches_products_zygote.jl:41
 [16] top-level scope
    @ REPL[7]:1

This is with `[47a9eef4] SparseDiffTools v1.20.0` and `[082447d4] ChainRules v1.26.0`. The issue was originally noticed here, where it was suggested that I report it to this repository instead.
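Frames [1] to [3] of the trace point at the immediate cause: a `ForwardDiff.Dual` is being written into a `Vector{Float64}`, and there is no conversion from a dual number to `Float64`. The snippet below is only a minimal sketch of that one step in isolation; `d` and `buf` are hypothetical stand-ins (not names from SparseDiffTools or ChainRulesCore), and it does not reproduce the sparse projection itself.

```julia
using ForwardDiff

# Hypothetical stand-in for the dual numbers that reach the reverse pass
# (frame [13] shows Zygote.gradient being called on a vector of duals).
d = ForwardDiff.Dual(1.0, 1.0)

# Hypothetical stand-in for the Float64 storage seen in frame [2].
buf = zeros(Float64, 1)

# Storing into a Vector{Float64} goes through convert(Float64, d),
# which is the call that fails in frame [1].
caught = try
    buf[1] = d
    nothing
catch err
    err
end

println(typeof(caught))
```

If this is indeed the mechanism, any code path that allocates a concretely typed `Float64` buffer during the pullback would hit the same error once duals flow through it.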

newalexander · Feb 12 '22 13:02

[3] (::ChainRulesCore.ProjectTo{SparseMatrixCSC, NamedTuple{(:element, :axes, :rowval, :nzranges, :colptr), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Vector{Int64}, Vector{UnitRange{Int64}}, Vector{Int64}}}})(dx::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1}}) @ ChainRulesCore ~/.julia/packages/ChainRulesCore/uxrij/src/projection.jl:580

It seems like it's creating a SparseMatrixCSC of Float64. @DhairyaLGandhi how would I investigate why?
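One possible first step for that investigation, offered as a sketch rather than a diagnosis, is to print the element types involved in the reproduction and the projector that ChainRulesCore builds for `A`. The setup lines are taken from the original report; nothing here asserts what the output will be on a given Julia or package version.

```julia
using SparseArrays, ChainRulesCore

# Same setup as in the report above.
x = rand(Float32, 5)
A = sprand(Float32, 5, 5, 0.5)

# Element type of the sparse matrix and of the product used inside `loss`.
@show eltype(A)
@show eltype(A * x)

# The projector type records which element type cotangents are projected to;
# compare it with the ProjectTo{Float64, ...} shown in frame [3].
@show typeof(ProjectTo(A))
```

If `eltype(A)` is already `Float64` here, the `Float64` buffer would come from the matrix itself rather than from anything in `autoback_hesvec`.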

ChrisRackauckas · Mar 10 '22 06:03