ForwardDiff.jl
Supporting copysign and flipsign
I have some code that uses copysign and flipsign. Rather than replacing these calls with sign-based equivalents, I was hoping to add support for copysign and flipsign to ForwardDiff. Here's what I've tried:
Status quo on 0.7.0
Julia Version 0.7.0 (2018-08-08 06:46 UTC), x86_64-linux-gnu
julia> using ForwardDiff
julia> ForwardDiff.derivative(x -> flipsign(1., x), 1.)
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##3#4")),Float64},Float64,1})
Closest candidates are:
Float64(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:173
Float64(::T<:Number) where T<:Number at boot.jl:725
Float64(::Int8) at float.jl:60
...
Stacktrace:
[1] Float64(::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##3#4")),Float64},Float64,1}) at ./deprecated.jl:468
[2] flipsign(::Float64, ::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##3#4")),Float64},Float64,1}) at ./floatfuncs.jl:13
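The failure comes from Base's `flipsign(x::Float64, y::Real)` method (floatfuncs.jl:13 in the trace), which converts `y` to `Float64` before flipping. The minimal reproduction below uses a hypothetical `Wrapped <: Real` type as a stand-in for a dual number, together with hypothetical `flipsign_generic`/`copysign_generic` helpers of the sign-based kind mentioned at the top. It is a sketch, assuming only that the argument type implements `signbit` and unary minus; a formulation built on those never needs the `Float64` conversion.

```julia
# Hypothetical stand-in for a dual number: a Real subtype that implements
# signbit and negation but has no conversion to Float64.
struct Wrapped <: Real
    v::Float64
end
Base.signbit(w::Wrapped) = signbit(w.v)
Base.:-(w::Wrapped) = Wrapped(-w.v)

# Base's flipsign(::Float64, ::Real) calls Float64(y), which has no method here.
base_fails = try
    flipsign(1.0, Wrapped(-2.0))
    false
catch err
    err isa MethodError
end

# Hypothetical sign-based helpers: they only need signbit, unary minus and
# ifelse, so they never convert their arguments to Float64.
flipsign_generic(x::Real, y::Real) = ifelse(signbit(y), -x, x)
copysign_generic(x::Real, y::Real) = ifelse(signbit(x) != signbit(y), -x, x)
```

For plain `Float64` inputs these helpers agree with Base, including signed zeros. A naive `x * sign(y)` would not: `sign(0.0) == 0.0`, whereas `flipsign(2.0, 0.0) == 2.0`.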
Hacking around the missing method
julia> Core.Float64(x::ForwardDiff.Dual) = convert(Float64, x)
julia> ForwardDiff.derivative(x -> flipsign(1., x), 1.)
ERROR: StackOverflowError:
Stacktrace:
[1] Float64(::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##5#6")),Float64},Float64,1}) at ./REPL[4]:1
[2] convert(::Type{Float64}, ::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##5#6")),Float64},Float64,1}) at ./number.jl:7
... (the last 2 lines are repeated 39998 more times)
[79999] Float64(::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##5#6")),Float64},Float64,1}) at ./REPL[4]:1
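The recursion is visible in the trace: the new `Float64` method calls `convert`, and Base's generic `convert(::Type{T}, x::Number)` fallback (number.jl in the trace) calls `T(x)`, i.e. `Float64(x)` again. A minimal reproduction, using a hypothetical `Opaque <: Real` type in place of `ForwardDiff.Dual`:

```julia
# Hypothetical Real subtype standing in for ForwardDiff.Dual.
struct Opaque <: Real
    v::Float64
end

# Same pattern as the hack above: define the Float64 constructor via convert.
Base.Float64(x::Opaque) = convert(Float64, x)

# convert(Float64, x) falls back to Float64(x), which calls convert again,
# so the two methods call each other until the stack overflows.
overflowed = try
    Float64(Opaque(1.0))
    false
catch err
    err isa StackOverflowError
end
```

Even a non-recursive conversion would have to read the stored value directly (for ForwardDiff, `ForwardDiff.value`) and would drop the partials, so the derivative information would be lost at that point anyway.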
Adding rules to DiffRules.jl
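The rules below follow from writing both functions with a piecewise-constant sign factor $s$ (notation introduced here for illustration, not taken from DiffRules):

```latex
s(t) = \begin{cases} -1 & \text{if } \operatorname{signbit}(t) \\ +1 & \text{otherwise} \end{cases}

\operatorname{flipsign}(x, y) = s(y)\,x
\quad\Rightarrow\quad
\frac{\partial}{\partial x}\operatorname{flipsign}(x, y) = s(y), \qquad
\frac{\partial}{\partial y}\operatorname{flipsign}(x, y) = 0 \quad (y \neq 0)

\operatorname{copysign}(x, y) = s(x)\,s(y)\,x = s(y)\,\lvert x \rvert
\quad\Rightarrow\quad
\frac{\partial}{\partial x}\operatorname{copysign}(x, y) = s(x)\,s(y) \quad (x \neq 0), \qquad
\frac{\partial}{\partial y}\operatorname{copysign}(x, y) = 0 \quad (y \neq 0)
```

So $s(y)$ corresponds to `signbit($y) ? -one($x) : one($x)`, $s(x)\,s(y)$ corresponds to `signbit($y) == signbit($x) ? one($x) : -one($x)`, and the second-argument partial is zero wherever $y$ does not cross zero.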
# within DiffRules.jl/src/rules.jl
@define_diffrule Base.copysign(x,y) = :( signbit($y) == signbit($x) ? one($x) : -one($x) ), :( zero($y) )
@define_diffrule Base.flipsign(x,y) = :( signbit($y) ? -one($x) : one($x) ), :( zero($y) )
With these rules added, the same call yields:
julia> ForwardDiff.derivative(x -> flipsign(1., x), 1.)
ERROR: MethodError: flipsign(::Float64, ::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##3#4")),Float64},Float64,1}) is ambiguous. Candidates:
flipsign(x::AbstractFloat, y::ForwardDiff.Dual{Ty,V,N} where N where V<:Real) where Ty in ForwardDiff at /home/schmrlng/.julia/packages/ForwardDiff/OXtu9/src/dual.jl:114
flipsign(x::Real, y::ForwardDiff.Dual{Ty,V,N} where N where V<:Real) where Ty in ForwardDiff at /home/schmrlng/.julia/packages/ForwardDiff/OXtu9/src/dual.jl:114
flipsign(x::Float64, y::Real) in Base at floatfuncs.jl:13
Possible fix, define
flipsign(::Float64, ::ForwardDiff.Dual{Ty,V,N} where N where V<:Real)
Stacktrace:
[1] (::getfield(Main, Symbol("##3#4")))(::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##3#4")),Float64},Float64,1}) at ./REPL[2]:1
[2] derivative(::getfield(Main, Symbol("##3#4")), ::Float64) at /home/schmrlng/.julia/packages/ForwardDiff/OXtu9/src/derivative.jl:14
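The ambiguity arises because, for a `(Float64, Dual)` call, Base's `(::Float64, ::Real)` method is narrower in the first argument while ForwardDiff's `(::Real, ::Dual)` and `(::AbstractFloat, ::Dual)` methods are narrower in the second, so no method is most specific. The `AbstractFloat` variant does not help because `Float64` is narrower still. The self-contained sketch below mimics that method table with a hypothetical `myflip` function and `FakeDual` type (not Base or ForwardDiff code):

```julia
# Hypothetical stand-ins that mimic the method table in the error above,
# without touching Base.flipsign or ForwardDiff.
struct FakeDual <: Real
    value::Float64
    partial::Float64
end

# Plays the role of Base's flipsign(x::Float64, y::Real) (floatfuncs.jl:13).
myflip(x::Float64, y::Real) = "Base-like method"
# Play the role of ForwardDiff's (::Real, ::Dual) and (::AbstractFloat, ::Dual) methods.
myflip(x::Real, y::FakeDual) = "ForwardDiff-like Real method"
myflip(x::AbstractFloat, y::FakeDual) = "ForwardDiff-like AbstractFloat method"

# For (Float64, FakeDual), the Base-like method is narrower in the first
# argument and the others are narrower in the second: dispatch is ambiguous.
ambiguous = try
    myflip(1.0, FakeDual(1.0, 1.0))
    false
catch err
    err isa MethodError
end

# Defining the exact intersection, as the error message suggests, resolves it.
myflip(x::Float64, y::FakeDual) = "concrete Float64 method"
```

After the last definition, `(Float64, FakeDual)` calls reach the concrete method, while other first-argument types still reach the generic `Real` or `AbstractFloat` methods.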
It looks like ForwardDiff already defines some methods specifically to avoid ambiguities like the one above (see https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/dual.jl?utf8=%E2%9C%93#L114). Should we consider also adding the concrete FloatXX types (e.g. Float32 and Float64) to AMBIGUOUS_TYPES? Or is there a better-practice way for me to differentiate through the second argument of copysign and flipsign?
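As a rough illustration of the first option (concrete `Float64` methods), here is a self-contained sketch using a hypothetical minimal dual type `MiniDual` and hypothetical helpers `primal` and `signfactor` instead of `ForwardDiff.Dual`. It encodes the same derivative rules as the DiffRules definitions above, but it does not reflect how ForwardDiff's AMBIGUOUS_TYPES machinery actually generates its methods. Other concrete float types that Base special-cases (e.g. Float32) would presumably need the same treatment.

```julia
# Hypothetical minimal forward-mode dual number (not ForwardDiff.Dual):
# `value` is the primal value, `partial` its derivative w.r.t. the input.
struct MiniDual <: Real
    value::Float64
    partial::Float64
end

# Hypothetical helpers: primal value of a (possibly dual) number, and the
# sign factor s(t) = -1 if signbit(t), else +1.
primal(t::Real) = t
primal(t::MiniDual) = t.value
signfactor(t::Real) = signbit(primal(t)) ? -1.0 : 1.0

# Concrete Float64 first argument: strictly more specific than Base's
# (::Float64, ::Real) methods, so no ambiguity. The derivative w.r.t. the
# second argument is zero (away from y = 0), hence a zero partial.
Base.flipsign(x::Float64, y::MiniDual) = MiniDual(flipsign(x, y.value), 0.0)
Base.copysign(x::Float64, y::MiniDual) = MiniDual(copysign(x, y.value), 0.0)

# Dual first argument: d/dx flipsign = s(y), d/dx copysign = s(x) * s(y).
Base.flipsign(x::MiniDual, y::Real) =
    MiniDual(flipsign(x.value, primal(y)), signfactor(y) * x.partial)
Base.copysign(x::MiniDual, y::Real) =
    MiniDual(copysign(x.value, primal(y)), signfactor(x) * signfactor(y) * x.partial)
```

For example, `flipsign(1.0, MiniDual(-2.0, 1.0))` gives value `-1.0` with a zero partial, and `copysign(MiniDual(-3.0, 1.0), 2.0)` gives value `3.0` with partial `-1.0`, matching d/dx |x| at x = -3.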