NaNs when using ForwardDiff with RodriguesVec
For example:
julia> using ForwardDiff, Rotations
julia> ForwardDiff.derivative(0) do x
           rotation_angle(RodriguesVec(x, 0, 0))
       end
NaN
This comes from the way the rotation angle is computed, which involves sqrt(rv.sx * rv.sx + ...): the derivative of sqrt at 0 is +Inf (the one-sided limit from above), and the chain rule then multiplies that Inf by a zero inner derivative, giving NaN.
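The same failure is easy to reproduce in isolation; this just exercises ForwardDiff's sqrt rule, independent of Rotations:

using ForwardDiff

# The partial ForwardDiff uses for sqrt is 1 / (2 * sqrt(x)), which is Inf at 0:
ForwardDiff.derivative(sqrt, 0.0)              # Inf

# Inside a sum of squares the inner derivative at 0 is 0, so we get Inf * 0:
ForwardDiff.derivative(x -> sqrt(x * x), 0.0)  # NaN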
However, there is a correct answer for this derivative (it's 0), so I wonder if we can make the computation a bit more robust. For example, the generic LinearAlgebra.norm does the right thing:
julia> ForwardDiff.derivative(0.0) do x
           rv = RodriguesVec(x, 0, 0)
           norm([rv.sx, rv.sy, rv.sz])
       end
0.0
For the generic norm, this happens because of this line:
https://github.com/JuliaLang/julia/blob/6da7aa8faf5cff49c462918a339435f4d703b999/stdlib/LinearAlgebra/src/generic.jl#L403
If we wanted to do this in Rotations, I think we'd similarly have to add a branch that explicitly checks for the case that all components are zero. This might have a small cost associated with it, but we'd have to benchmark; it might also be the case that explicitly checking whether the sum of the squared components is nonpositive allows LLVM to eliminate the branch that throws a DomainError for negative inputs to sqrt.
The explicit check for zero seems to not be too expensive:
julia> function f1(r)
           sqrt(r.sx * r.sx + r.sy * r.sy + r.sz * r.sz)
       end
f1 (generic function with 1 method)
julia> function f2(r)
           norm((r.sx, r.sy, r.sz))
       end
f2 (generic function with 1 method)
julia> function f3(r::RodriguesVec{T}) where {T}
           if r.sx == 0 && r.sy == 0 && r.sz == 0
               zero(T)
           else
               sqrt(r.sx * r.sx + r.sy * r.sy + r.sz * r.sz)
           end
       end
f3 (generic function with 1 method)
julia> r1 = RodriguesVec(0.5, 0, 0);
julia> r2 = RodriguesVec(0.0, 0, 0);
julia> using BenchmarkTools
julia> @btime f1($r1)
2.031 ns (0 allocations: 0 bytes)
0.5
julia> @btime f1($r2)
1.881 ns (0 allocations: 0 bytes)
0.0
julia> @btime f2($r1)
5.658 ns (0 allocations: 0 bytes)
0.5
julia> @btime f2($r2)
4.622 ns (0 allocations: 0 bytes)
0.0
julia> @btime f3($r1)
2.361 ns (0 allocations: 0 bytes)
0.5
julia> @btime f3($r2)
0.026 ns (0 allocations: 0 bytes)
0.0
It looks like with both f1 and f3 the negative-sqrt-argument (error) branch is being eliminated, so that's not an issue we need to worry about. I also tried
function f4(r::RodriguesVec{T}) where {T}
    s = r.sx * r.sx + r.sy * r.sy + r.sz * r.sz
    if s <= zero(T)
        zero(T)
    else
        sqrt(s)
    end
end
which generates slightly different code but seems to be about as fast as f3.
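For anyone who wants to compare the variants themselves (and check whether the DomainError branch for negative sqrt inputs survives), the generated code can be inspected directly; output not reproduced here:

using InteractiveUtils  # provides @code_llvm outside the REPL

@code_llvm f1(r1)  # look for a branch that throws for negative sqrt inputs
@code_llvm f3(r1)
@code_llvm f4(r1)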
To completely eliminate the overhead, we could also make f3 or f4 the generic fallback and use the existing implementation for AbstractFloat.
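As a sketch of what that split could look like (the helper name _angle_norm is hypothetical, purely to illustrate the dispatch; the real code would go wherever rotation_angle currently computes the norm of the rotation vector):

# Generic fallback: mirrors f4 above, safe for ForwardDiff.Dual and other
# non-float element types.
function _angle_norm(r::RodriguesVec{T}) where {T}
    s = r.sx * r.sx + r.sy * r.sy + r.sz * r.sz
    s <= zero(T) ? zero(T) : sqrt(s)
end

# Ordinary floats keep the existing branch-free path.
function _angle_norm(r::RodriguesVec{T}) where {T<:AbstractFloat}
    sqrt(r.sx * r.sx + r.sy * r.sy + r.sz * r.sz)
end

That way plain Float64 rotations pay nothing extra, and only dual numbers (and other non-AbstractFloat element types) take the checked path.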