ChainRulesTestUtils.jl icon indicating copy to clipboard operation
ChainRulesTestUtils.jl copied to clipboard

ChainRulesTestUtils crashes with a custom regularization function on a Flux Dense layer

Open TPU22 opened this issue 3 years ago • 4 comments

I have been trying to write a custom reverse rule to a simple regularization function on a Flux Dense layer, and evaluate it with ChainRulesTestUtils. The function gradient from Zygote seems to work fine with the rules, but ChainRulesTestUtils crashes. The following code is executed just fine until the test_rrule calls. The first test_rrule tries to check whether the one-layer regularization function works, but instead it raises an error

Got exception outside of a @test 
MethodError: no method matching zero(::typeof(tanh))

The second test_rrule crashes with

Got exception outside of a @test
return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}}}} does not match inferred return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}}

Any idea what could be the issue here? A bug somewhere?


using ChainRulesCore
using Flux
using Random 
using ChainRulesTestUtils

# Restrict Flux's trainable parameters for Dense to (weight, bias); the
# activation σ is excluded. NOTE(review): this defines a Flux-owned function
# on a Flux-owned type from user code (type piracy) — it affects every Dense
# in the session, not just the ones in this script.
Flux.trainable(nn::Dense) = (nn.weight, nn.bias,)

"""
    weightregularization(nn::Dense)

Return the squared-Frobenius-norm regularization term `sum(nn.weight .^ 2)`
for a single Dense layer.
"""
function weightregularization(nn::Dense)
    # sum(abs2, w) computes the same value as sum(w .^ 2.0) but avoids the
    # float `^` call per element and the temporary array allocation.
    return sum(abs2, nn.weight)
end

# Reverse rule for the single-layer regularizer: y = Σ wᵢⱼ², so ∂y/∂W = 2W.
# bias and σ do not enter the computation, so their cotangents are
# ZeroTangent()/NoTangent() respectively.
function ChainRulesCore.rrule(::typeof(weightregularization), nn::Dense)
    y = weightregularization(nn)
    project_w = ProjectTo(nn.weight)
    function weightregularization_pullback(ȳ)
        ∂nn = Tangent{Dense}(;
            weight = project_w(ȳ * 2.0 * nn.weight),
            bias = ZeroTangent(),
            σ = NoTangent(),
        )
        return NoTangent(), ∂nn
    end
    return y, weightregularization_pullback
end


"""
    totalregularization(ch::Chain)

Return the sum of the squared-Frobenius-norm regularization terms over all
Dense layers of `ch`.
"""
function totalregularization(ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    # sum(abs2, w) avoids the per-element float power `.^2.0` and the
    # temporary array the broadcast would allocate. `init=0.0` preserves the
    # original `a = 0.0` behavior (Float64 result, 0.0 for an empty Chain).
    return sum(layer -> sum(abs2, layer.weight), ch; init=0.0)
end

# Reverse rule for the whole-Chain regularizer. Per layer: ∂y/∂Wᵢ = 2Wᵢ.
#
# Fix for the test_rrule inference failure: the original built the layer
# tangents by push!-ing into an untyped Vector (`[]`) and then calling
# Tuple(...) on it — the compiler cannot infer the length or eltype of that
# Tuple, so the pullback's return type is not inferrable and test_rrule's
# @inferred check fails. Mapping over the (typed) tuple `ch.layers` keeps
# everything a concrete Tuple and is inferrable.
function ChainRulesCore.rrule(::typeof(totalregularization), ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    y = totalregularization(ch)
    function totalregularization_pullback(ȳ)
        layer_tangents = map(ch.layers) do layer
            project_w = ProjectTo(layer.weight)
            # Structural tangent for one Dense layer; bias/σ unused by y.
            Tangent{typeof(layer)}(;
                weight = project_w(ȳ * 2.0 * layer.weight),
                bias = ZeroTangent(),
                σ = NoTangent(),
            )
        end
        return NoTangent(), Tangent{Chain{T}}(; layers = layer_tangents)
    end
    return y, totalregularization_pullback
end




# Single Dense layer: 2 inputs -> 1 output, tanh activation.
nn = Dense(randn(1,2), randn(1), tanh)
# Zygote/Flux gradient through the custom rrule — reported to work fine.
gr1 = gradient(weightregularization,nn)


# Two-layer Chain for the whole-network regularizer.
l1 = Dense(randn(2,2), randn(2), tanh)
l2 = Dense(randn(1,2), randn(1), tanh)
ch = Chain(l1,l2)
gr2 = gradient(totalregularization,ch)


# The two failing calls from the issue report: the first errors inside
# finite differencing (`zero(::typeof(tanh))`), the second fails the
# return-type inference check.
test_rrule(weightregularization,nn) 
test_rrule(totalregularization,ch)

TPU22 avatar Nov 15 '22 17:11 TPU22

Odds are that, as a workaround, you need to either implement FiniteDifferences.to_vec or maybe rand_tangent (https://github.com/JuliaDiff/ChainRulesTestUtils.jl/blob/6925da14c12e3d743c8d3620db8a8bee1433d5c3/src/rand_tangent.jl#L8) for your type. I would need to see the stack trace to know which.

Either this to_vec method or this rand_tangent method is not smart enough to handle a struct where some fields are differentiable and others are not.

Probably the extra smartness needed is to know that for objects that have no fields (like nonclosure/functor functions), they are NoTangent()

oxinabox avatar Nov 21 '22 14:11 oxinabox

Here's the stacktrace for test_rrule(weightregularization,nn):

 Stacktrace:
    [1] test_approx(::AbstractZero, x::Any, msg::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:42
    [2] test_approx(actual::Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}, expected::Any, msg::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:134
    [3] _test_cotangent(accum_cotangent::Any, ad_cotangent::Any, fd_cotangent::Any; check_inferred::Any, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:299
    [4] (::ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})(::Any, ::Vararg{Any})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:226
    [5] (::Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}})(#unused#::Nothing, xs::Tuple{Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}})
      @ Base ./tuple.jl:556
    [6] BottomRF
      @ ./reduce.jl:81 [inlined]
    [7] _foldl_impl(op::Base.BottomRF{Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}})
      @ Base ./reduce.jl:62
    [8] foldl_impl(op::Base.BottomRF{Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}}}, nt::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}})
      @ Base ./reduce.jl:48
    [9] mapfoldl_impl(f::typeof(identity), op::Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}}, nt::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}})
      @ Base ./reduce.jl:44
   [10] mapfoldl(f::Function, op::Function, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}; init::Nothing)
      @ Base ./reduce.jl:170
   [11] #foldl#260
      @ ./reduce.jl:193 [inlined]
   [12] foreach(::Function, ::Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, ::Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, ::Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}})
      @ Base ./tuple.jl:556
   [13] macro expansion
      @ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:225 [inlined]
   [14] macro expansion
      @ /opt/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:1360 [inlined]
   [15] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196
   [16] test_rrule(config::RuleConfig, f::Any, args::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:173
   [17] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170
   [18] test_rrule(::Any, ::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:169
   [19] top-level scope
      @ ~/Julia/customreg.jl:61

and here's the stacktrace for test_rrule(totalregularization,ch):

Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:35
    [2] _test_inferred(f::Any, args::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:255
    [3] _test_inferred
      @ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:253 [inlined]
    [4] macro expansion
      @ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:211 [inlined]
    [5] macro expansion
      @ /opt/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:1360 [inlined]
    [6] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196
    [7] test_rrule(config::RuleConfig, f::Any, args::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:173
    [8] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170
    [9] test_rrule(::Any, ::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:169
   [10] top-level scope
      @ ~/Julia/customreg.jl:62

TPU22 avatar Nov 24 '22 09:11 TPU22

So the problem with the second one is just that it isn't inferrable. So ChainRulesTestUtils is working correctly there. `Tuple(totalpullback)` is not an inferrable operation, because the compiler can't tell how long the Tuple will be. Rather than writing it with a for loop + push!-ing to a vector, consider using the ntuple function or map (on a Tuple) -- that tends to be inferrable. Or you can use test_rrule(...; check_inferred=false)

oxinabox avatar Dec 08 '22 19:12 oxinabox

For the first case, what seems to be happening is that finite differencing is returning a result that contains a tanh element, which is weird.

Something is going wrong here: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/blob/6925da14c12e3d743c8d3620db8a8bee1433d5c3/src/finite_difference_calls.jl#L43-L58

But I have no idea what.

FiniteDifferences.to_vec looks right

julia> nn = Dense(randn(1,2), randn(1), tanh)
Dense(2 => 1, tanh)  # 3 parameters

julia> FiniteDifferences.to_vec(nn)
([0.6512775053740293, 0.09461593565834911, 2.580817436122576], FiniteDifferences.var"#structtype_from_vec#29"{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, FiniteDifferences.var"#Tuple_from_vec#52"{Tuple{Int64, Int64, Int64}, Tuple{Int64, Int64, Int64}, Tuple{typeof(identity), typeof(identity), typeof(identity)}}, Tuple{FiniteDifferences.var"#Array_from_vec#34"{Matrix{Float64}, typeof(identity)}, typeof(identity), FiniteDifferences.var"#24#27"{typeof(tanh)}}}(FiniteDifferences.var"#Tuple_from_vec#52"{Tuple{Int64, Int64, Int64}, Tuple{Int64, Int64, Int64}, Tuple{typeof(identity), typeof(identity), typeof(identity)}}((2, 3, 3), (2, 1, 0), (identity, identity, identity)), (FiniteDifferences.var"#Array_from_vec#34"{Matrix{Float64}, typeof(identity)}([0.6512775053740293 0.09461593565834911], identity), identity, FiniteDifferences.var"#24#27"{typeof(tanh)}(tanh))))

oxinabox avatar Dec 08 '22 20:12 oxinabox