ChainRulesTestUtils crashes with a custom regularization function on a Flux Dense layer
I have been trying to write a custom reverse rule for a simple regularization function on a Flux Dense layer and test it with ChainRulesTestUtils. Computing the gradient with Zygote seems to work fine with the rules, but ChainRulesTestUtils crashes. The code below runs fine until the test_rrule calls. The first test_rrule call is meant to check the single-layer regularization function, but instead it raises the error
Got exception outside of a @test
MethodError: no method matching zero(::typeof(tanh))
The second test_rrule crashes with
Got exception outside of a @test
return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}}}} does not match inferred return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}}
Any idea what could be the issue here? A bug somewhere?
using ChainRulesCore
using Flux
using Random
using ChainRulesTestUtils

Flux.trainable(nn::Dense) = (nn.weight, nn.bias,)

function weightregularization(nn::Dense)
    return sum((nn.weight).^2.0)
end

function ChainRulesCore.rrule(::typeof(weightregularization), nn::Dense)
    y = weightregularization(nn)
    project_w = ProjectTo(nn.weight)
    function weightregularization_pullback(ȳ)
        pullb = Tangent{Dense}(weight=project_w(ȳ * 2.0*nn.weight), bias=ZeroTangent(), σ= NoTangent())
        return NoTangent(), pullb
    end
    return y, weightregularization_pullback
end

function totalregularization(ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    a = 0.0
    for i in ch
        a = a + sum(i.weight.^2.0)
    end
    return a
end

function ChainRulesCore.rrule(::typeof(totalregularization), ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    y = totalregularization(ch)
    function totalregularization_pullback(ȳ)
        totalpullback = []
        N = length(ch)
        for i = 1:N
            project_w = ProjectTo(ch[i].weight)
            push!(totalpullback, (weight= project_w(ȳ * 2.0*ch[i].weight), bias = ZeroTangent(), σ= NoTangent()))
        end
        pullb = Tangent{Chain{T}}(layers=Tuple(totalpullback))
        return NoTangent(), pullb
    end
    return y, totalregularization_pullback
end

nn = Dense(randn(1,2), randn(1), tanh)
gr1 = gradient(weightregularization,nn)

l1 = Dense(randn(2,2), randn(2), tanh)
l2 = Dense(randn(1,2), randn(1), tanh)
ch = Chain(l1,l2)
gr2 = gradient(totalregularization,ch)

test_rrule(weightregularization,nn)
test_rrule(totalregularization,ch)
Odds are that, as a workaround, you need to implement either FiniteDifferences.to_vec or possibly rand_tangent (https://github.com/JuliaDiff/ChainRulesTestUtils.jl/blob/6925da14c12e3d743c8d3620db8a8bee1433d5c3/src/rand_tangent.jl#L8) for your type.
I would need to see the stack trace to know which.
Either the to_vec method or the rand_tangent method is not smart enough to handle a struct where some fields are differentiable and others are not.
The extra smartness needed is probably to know that objects with no fields (such as non-closure functions or field-less functors) get NoTangent().
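To make that rule concrete, here is a minimal sketch. It assumes "has no fields" can be approximated by Julia's singleton-type check (a non-closure function such as tanh has a singleton type), using the non-exported Base.issingletontype. The helper fieldless_to_notangent is hypothetical and not part of ChainRulesTestUtils; it only illustrates the idea on the fields of a Dense layer.

```julia
using ChainRulesCore
using Flux

# Hypothetical helper, for illustration only (not a ChainRulesTestUtils API):
# a value whose type is a singleton (no fields, e.g. a non-closure function
# such as tanh) has nothing to differentiate, so map it to NoTangent().
# Any other value is returned unchanged.
fieldless_to_notangent(x) = Base.issingletontype(typeof(x)) ? NoTangent() : x

nn = Dense(randn(1, 2), randn(1), tanh)

# Apply the rule field by field to the layer's contents.
fields = (weight = nn.weight, bias = nn.bias, σ = nn.σ)
cleaned = map(fieldless_to_notangent, fields)

cleaned.σ isa NoTangent       # true: typeof(tanh) is a singleton type
cleaned.weight === nn.weight  # true: arrays are left alone
```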
Here's the stacktrace for test_rrule(weightregularization,nn):
Stacktrace:
[1] test_approx(::AbstractZero, x::Any, msg::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:42
[2] test_approx(actual::Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}, expected::Any, msg::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:134
[3] _test_cotangent(accum_cotangent::Any, ad_cotangent::Any, fd_cotangent::Any; check_inferred::Any, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:299
[4] (::ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})(::Any, ::Vararg{Any})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:226
[5] (::Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}})(#unused#::Nothing, xs::Tuple{Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}})
@ Base ./tuple.jl:556
[6] BottomRF
@ ./reduce.jl:81 [inlined]
[7] _foldl_impl(op::Base.BottomRF{Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}})
@ Base ./reduce.jl:62
[8] foldl_impl(op::Base.BottomRF{Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}}}, nt::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}})
@ Base ./reduce.jl:48
[9] mapfoldl_impl(f::typeof(identity), op::Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}}, nt::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}})
@ Base ./reduce.jl:44
[10] mapfoldl(f::Function, op::Function, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}; init::Nothing)
@ Base ./reduce.jl:170
[11] #foldl#260
@ ./reduce.jl:193 [inlined]
[12] foreach(::Function, ::Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, ::Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, ::Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}})
@ Base ./tuple.jl:556
[13] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:225 [inlined]
[14] macro expansion
@ /opt/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:1360 [inlined]
[15] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196
[16] test_rrule(config::RuleConfig, f::Any, args::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:173
[17] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170
[18] test_rrule(::Any, ::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:169
[19] top-level scope
@ ~/Julia/customreg.jl:61
And here's the stacktrace for test_rrule(totalregularization,ch):
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] _test_inferred(f::Any, args::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:255
[3] _test_inferred
@ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:253 [inlined]
[4] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:211 [inlined]
[5] macro expansion
@ /opt/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:1360 [inlined]
[6] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196
[7] test_rrule(config::RuleConfig, f::Any, args::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:173
[8] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170
[9] test_rrule(::Any, ::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:169
[10] top-level scope
@ ~/Julia/customreg.jl:62
So the problem with the second one is just that the pullback isn't inferrable, which means ChainRulesTestUtils is working correctly there.
Tuple(totalpullback) is not an inferrable operation, because the compiler can't tell how long the Tuple will be: totalpullback = [] is a Vector{Any} whose length is only known at run time.
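One way to see the difference is to ask the compiler what it infers. This sketch uses Base.return_types on two hypothetical one-line functions: one builds a Tuple from a Vector{Any}, as the pullback above does, and the other maps over a Tuple.

```julia
# Two hypothetical one-liners that build a Tuple in different ways.
from_vector(v) = Tuple(v)             # length only known at run time
from_tuple(t) = map(x -> 2.0 * x, t)  # length fixed by the input's type

T1 = only(Base.return_types(from_vector, (Vector{Any},)))
T2 = only(Base.return_types(from_tuple, (Tuple{Float64, Float64},)))

isconcretetype(T1)  # false: a Tuple of unknown length
isconcretetype(T2)  # true: T2 is Tuple{Float64, Float64}
```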
Rather than writing it with a for loop and push!-ing into a vector, consider using the ntuple function or map (on a Tuple), which tends to be inferrable.
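A sketch of the pullback rewritten with map, assuming the Chain stores its layers as a Tuple in ch.layers (which is what Chain(l1, l2) produces). It also calls unthunk on ȳ, since test_rrule can pass a thunked output tangent. The regularizer is written with sum(abs2, ...) but computes the same value as the loop version. I have not checked whether the finite-difference part of test_rrule then passes; it may still run into the same σ problem as the first test.

```julia
using ChainRulesCore
using Flux

# Same value as the loop version: sum of squared weights over all layers.
function totalregularization(ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    return sum(layer -> sum(abs2, layer.weight), ch.layers)
end

function ChainRulesCore.rrule(::typeof(totalregularization), ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    y = totalregularization(ch)
    function totalregularization_pullback(ȳ)
        dy = unthunk(ȳ)  # test_rrule may pass a thunked output tangent
        # map over the Tuple ch.layers returns a Tuple whose length is part of
        # its type, so the result stays inferrable (unlike Tuple(::Vector{Any})).
        layer_tangents = map(ch.layers) do layer
            project_w = ProjectTo(layer.weight)
            (weight = project_w(2.0 .* dy .* layer.weight), bias = ZeroTangent(), σ = NoTangent())
        end
        return NoTangent(), Tangent{Chain{T}}(layers = layer_tangents)
    end
    return y, totalregularization_pullback
end
```

With this in place, test_rrule(totalregularization, ch) exercises the same rule, but the inference check no longer has to see through a Vector{Any}.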
Alternatively, you can turn the inference check off with test_rrule(...; check_inferred=false).
For the first case, what seems to be happening is that finite differencing returns a result that contains a tanh element, which is weird.
Something is going wrong here: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/blob/6925da14c12e3d743c8d3620db8a8bee1433d5c3/src/finite_difference_calls.jl#L43-L58
But I have no idea what.
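The MethodError from the question fits this reading. Frame [1] of the first stack trace is test_approx(::AbstractZero, x), which presumably compares the rule's NoTangent() for σ against whatever finite differencing produced for that field, and zero has no method for a plain function such as tanh. A tiny check of that last point:

```julia
# zero works for numbers and arrays...
zero(0.5)          # 0.0
zero([1.0, 2.0])   # [0.0, 0.0]

# ...but has no method for a function, which reproduces the error in the question.
tanh_has_zero = try
    zero(tanh)
    true
catch err
    err isa MethodError || rethrow()
    false
end
```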
FiniteDifferences.to_vec looks right
julia> nn = Dense(randn(1,2), randn(1), tanh)
Dense(2 => 1, tanh) # 3 parameters
julia> FiniteDifferences.to_vec(nn)
([0.6512775053740293, 0.09461593565834911, 2.580817436122576], FiniteDifferences.var"#structtype_from_vec#29"{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, FiniteDifferences.var"#Tuple_from_vec#52"{Tuple{Int64, Int64, Int64}, Tuple{Int64, Int64, Int64}, Tuple{typeof(identity), typeof(identity), typeof(identity)}}, Tuple{FiniteDifferences.var"#Array_from_vec#34"{Matrix{Float64}, typeof(identity)}, typeof(identity), FiniteDifferences.var"#24#27"{typeof(tanh)}}}(FiniteDifferences.var"#Tuple_from_vec#52"{Tuple{Int64, Int64, Int64}, Tuple{Int64, Int64, Int64}, Tuple{typeof(identity), typeof(identity), typeof(identity)}}((2, 3, 3), (2, 1, 0), (identity, identity, identity)), (FiniteDifferences.var"#Array_from_vec#34"{Matrix{Float64}, typeof(identity)}([0.6512775053740293 0.09461593565834911], identity), identity, FiniteDifferences.var"#24#27"{typeof(tanh)}(tanh))))