ForwardDiff.jl

Can't differentiate through StepRangeLen due to TwicePrecision

Open howthebodyworks opened this issue 7 years ago • 7 comments

There is one very common data type that we can't differentiate through: StepRangeLen, which is what we usually get back from range.

julia> using ForwardDiff

julia> f1v(x::Vector) = sum(range(0; step=0.1, stop=2)*x[1]);  # edited not to clash with definitions below

julia> g1v = x -> ForwardDiff.gradient(f1v, x);

julia> f1v([1.0])
31.5

julia> g1v([1.0])
MethodError: no method matching twiceprecision(::Base.TwicePrecision{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f1),Float64},Float64,1}}, ::Int64)
Closest candidates are:
  twiceprecision(!Matched::T<:Union{Float16, Float32, Float64}, ::Integer) where T<:Union{Float16, Float32, Float64} at twiceprecision.jl:220
  twiceprecision(!Matched::Base.TwicePrecision{T<:Union{Float16, Float32, Float64}}, ::Integer) where T<:Union{Float16, Float32, Float64} at twiceprecision.jl:225

This problem does not arise with numerically equivalent data types which do not use TwicePrecision, such as LinRange:

julia> using ForwardDiff

julia> f2(x::Vector) = sum(LinRange(0, 2, 21)*x[1]);

julia> g2 = x -> ForwardDiff.gradient(f2, x);

julia> f2([1.0]), g2([1.0])
(21.0, [21.0])

Related: https://github.com/JuliaMath/Interpolations.jl/issues/293

howthebodyworks • Jan 14 '19 01:01

To be clear, we get the same problem from the more usual start:step:end syntax.

using ForwardDiff
f1 = x::Vector -> sum((0.5:0.1:2.5)*x[1]);
g1 = x -> ForwardDiff.gradient(f1, x);
f1([0.5]), g1([0.5])

danmackinlay • Feb 04 '19 09:02

Interestingly, you get a different error if you do the more-or-less-equivalent:

using ForwardDiff
f1 = x::Vector -> sum((0.5*x[1]):(0.1*x[1]):(2.5*x[1]));
g1 = x -> ForwardDiff.gradient(f1, x);
f1([0.5]), g1([0.5])

This throws

MethodError: no method matching Integer(::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##19#20")),Float64},Float64,1})
Closest candidates are:
  Integer(::T<:Number) where T<:Number at boot.jl:741
  Integer(!Matched::Integer) at boot.jl:765
  Integer(!Matched::Union{Float32, Float64}) at boot.jl:766
  ...

danmackinlay • Feb 04 '19 09:02

Just bumped into this issue myself; just commenting to raise awareness.

tom-plaa • Dec 12 '22 23:12

If you do range(0, Dual(1,1), length=10) you get a LinRange, which should work fine. Ideally most ranges with Dual numbers would make one, since this is much simpler than StepRangeLen.

  • The error case above seems to be @less Dual(1,1) * (0:0.1:1). It should be easy to add a method here for that. (0:0.1:1) / Dual(1,1) is similar (and nearby).

  • 0:Dual(1,1):4 tries to make a StepRange and fails; it should make a LinRange.

  • You can construct and index, but not show, this range. It should work just like /.

julia> r = (0:0.1:1) ./ Dual(1,1);

julia> sum(r), r[2]
(Dual{Nothing}(5.5,-5.5), Dual{Nothing}(0.1,-0.1))

julia> typeof(r)
StepRangeLen{Dual{Nothing, Float64, 1}, Base.TwicePrecision{Dual{Nothing, Float64, 1}}, Base.TwicePrecision{Dual{Nothing, Float64, 1}}, Int64}

julia> r
Error showing value of type StepRangeLen{Dual{Nothing, Float64, 1}, Base.TwicePrecision{Dual{Nothing, Float64, 1}}, Base.TwicePrecision{Dual{Nothing, Float64, 1}}, Int64}:
ERROR: MethodError: Dual{Nothing, Float64, 1}(::Base.TwicePrecision{Dual{Nothing, Float64, 1}}) is ambiguous.

Candidates:
  (::Type{T})(x::Base.TwicePrecision) where T<:Number
    @ Base twiceprecision.jl:266
  Dual{T, V, N}(x) where {T, V, N}
    @ ForwardDiff ~/.julia/packages/ForwardDiff/pDtsf/src/dual.jl:77

Possible fix, define
  Dual{T, V, N}(::Base.TwicePrecision) where {T, V, N}

Fixing any or all of these would be a nice first PR.
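
For the third bullet, the method named in the "Possible fix" hint could look roughly like the sketch below. It is not a tested patch: it assumes that summing the hi and lo fields is the right conversion (mirroring Base's generic conversion from TwicePrecision), and defining it outside ForwardDiff is type piracy, so it really belongs in a PR to the package.

```julia
import ForwardDiff
using ForwardDiff: Dual

# Sketch only: a method more specific than both ambiguous candidates.
# Assumption: adding the high and low parts is the intended conversion.
function ForwardDiff.Dual{T,V,N}(x::Base.TwicePrecision) where {T,V,N}
    return Dual{T,V,N}(x.hi) + Dual{T,V,N}(x.lo)
end

# The range from the example above; its step is stored as a TwicePrecision of Duals.
r = (0:0.1:1) ./ Dual(1, 1)

# With the method defined, this conversion should no longer be ambiguous.
step_as_dual = Dual{Nothing,Float64,1}(r.step)
```

Whether this alone is enough for show to work on such a range has not been checked here.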

mcabbott • Dec 13 '22 00:12

To make sure I understood, is your suggestion to make it default to a LinRange when trying to go through a StepRange using a Dual argument? Following the comments above, that's what I am currently trying to do in my code: replacing a StepRange object by a LinRange in a function that needs to be autodiff'ed. I was under the assumption that the desired behaviour would be to preserve the StepRange because of the floating point corrections.

tom-plaa • Dec 13 '22 12:12

StepRange is the one for integers, like 1:2:10. But my example is pretty artificial & probably won't happen in the wild.

Yes I'm suggesting that for most operations which need new methods involving Duals it would be simplest to avoid StepRangeLen. I doubt it's worth trying to make Dual & TwicePrecision work together better. This wouldn't change the clever things done on construction without Duals -- inferring (0:0.1:1).step to end exactly at 1 is nontrivial.

Explicitly writing LinRange in your code may work around this issue, and may help in the meantime, but it isn't what I'm suggesting. It shouldn't be hard to just fix it for everyone.

julia> using ForwardDiff: Dual

julia> Base.:*(x::Dual, r::StepRangeLen{<:Real,<:Base.TwicePrecision}) =
           StepRangeLen(x*r.ref, x*r.step, length(r), r.offset)  # one modification

julia> Dual(1,1) * (0:0.1:1) |> collect  # show hits ambiguity above, collect does not
11-element Vector{Dual{Nothing, Float64, 1}}:
 Dual{Nothing}(0.0,0.0)
 Dual{Nothing}(0.1,0.1)
 Dual{Nothing}(0.2,0.2)
 Dual{Nothing}(0.30000000000000004,0.30000000000000004)
 Dual{Nothing}(0.4,0.4)
 Dual{Nothing}(0.5,0.5)
 Dual{Nothing}(0.6000000000000001,0.6000000000000001)
 Dual{Nothing}(0.7000000000000001,0.7000000000000001)
 Dual{Nothing}(0.8,0.8)
 Dual{Nothing}(0.9,0.9)
 Dual{Nothing}(1.0,1.0)

julia> Base.:*(x::Dual, r::StepRangeLen{<:Real,<:Base.TwicePrecision}) =
                  LinRange(x*first(r), x*last(r), length(r))  # alternative

julia> Dual(1,1) * (0:0.1:1) 
11-element LinRange{Dual{Nothing, Float64, 1}, Int64}:
 Dual{Nothing}(0.0,0.0), Dual{Nothing}(0.1,0.1), …, Dual{Nothing}(1.0,1.0)

julia> collect(ans)
11-element Vector{Dual{Nothing, Float64, 1}}:
 Dual{Nothing}(0.0,0.0)
 Dual{Nothing}(0.1,0.1)
 Dual{Nothing}(0.2,0.2)
 Dual{Nothing}(0.3,0.3)
 Dual{Nothing}(0.4,0.4)
 Dual{Nothing}(0.5,0.5)
 Dual{Nothing}(0.6,0.6)
 Dual{Nothing}(0.7,0.7)
 Dual{Nothing}(0.8,0.8)
 Dual{Nothing}(0.9,0.9)
 Dual{Nothing}(1.0,1.0)
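
For comparison, the user-side workaround discussed earlier in the thread (writing LinRange explicitly instead of waiting for a fix) might be sketched as follows. The helper name as_linrange is hypothetical and not part of Base or ForwardDiff; it gives up the TwicePrecision corrections of StepRangeLen in exchange for a range type that Dual numbers can pass through.

```julia
using ForwardDiff

# Hypothetical helper: rebuild any range as a LinRange with the same
# endpoints and length, dropping the TwicePrecision bookkeeping.
as_linrange(r::AbstractRange) = LinRange(first(r), last(r), length(r))

# Same function as f1 above, but the range is converted before it meets x[1].
f_lin(x::Vector) = sum(as_linrange(0.5:0.1:2.5) * x[1])
g_lin(x) = ForwardDiff.gradient(f_lin, x)

value = f_lin([0.5])
grad = g_lin([0.5])
```

The conversion happens on the plain Float64 range, before any Dual is involved, so the clever endpoint inference of 0.5:0.1:2.5 is still used to compute first, last, and length.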

mcabbott • Dec 13 '22 14:12

Got it, thanks. I may try to take a shot at defining those when I manage to have some free time, then.

tom-plaa • Dec 13 '22 14:12