Can't differentiate an ODE solver due to lack of isnan and other type errors.
TaylorDiff.jl seems to throw an error when I try to differentiate a fairly simple ODE solver. There is an error on the MTK side, but even after applying the workaround there (see https://discourse.julialang.org/t/error-trying-to-forwarddiff-through-an-ode-solver/114339/6) I can't get TaylorDiff.jl to work.
MWE:
using ModelingToolkit, DifferentialEquations
using TaylorDiff, ForwardDiff
using DifferentiationInterface, Enzyme, Zygote, ReverseDiff
using SciMLSensitivity
#import Base.isnan
#function isnan(x::TaylorScalar{Float64, 2})
# return false
#end
"""
    ADTest()

Minimal working example for differentiating through an ODE solve.

Builds the scalar ODE `D(x1) ~ a * x1` with ModelingToolkit (`x1(0) = 1.0`,
`a = 2.0`), solves it once on `[0.0, 1.0]`, and then differentiates the value of
`x1` at the end time with respect to that end time using
`TaylorDiff.derivative`. Intermediate results are shown with `display`.
"""
function ADTest()
    @parameters a
    @variables t x1(t)
    D = Differential(t)
    states = [x1]
    parameters = [a]
    @named pre_model = ODESystem([D(x1) ~ a * x1], t, states, parameters)
    model = structural_simplify(pre_model)

    ic = Dict(x1 => 1.0)
    p_true = Dict(a => 2.0)
    problem = ODEProblem{true, SciMLBase.FullSpecialize}(model, ic, [0.0, 1.0], p_true)
    soln = ModelingToolkit.solve(problem, Tsit5(), abstol = 1e-12, reltol = 1e-12)
    display(soln(0.5, idxs = [x1]))

    # Re-solve the problem on `[0.0, new_t]` with the given initial conditions and
    # parameters, and return the requested state evaluated at `new_t`.
    function different_time(new_ic, new_params, new_t)
        # Alternative: rebuild the problem from the model instead of using `remake`.
        #newprob = ODEProblem{true, SciMLBase.FullSpecialize}(model, new_ic, [0.0, new_t*2], new_params)
        newprob = remake(problem, u0 = new_ic, tspan = [0.0, new_t], p = new_params)
        # Convert every entry of u0 to the number type of `new_t`, so that the state
        # and the time span share one element type when `new_t` is an AD number.
        newprob = remake(newprob, u0 = typeof(new_t).(newprob.u0))
        new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
        # Evaluate the freshly computed solution. The original code evaluated the
        # outer `soln` here, which ignored the re-solve (and therefore `new_ic`,
        # `new_params` and the new time span) entirely.
        return new_soln(new_t, idxs = [x1])
    end

    # Scalar function of the end time only; this is the function being differentiated.
    function just_t(new_t)
        return different_time(ic, p_true, new_t)[1]
    end

    display(different_time(ic, p_true, 2e-5))
    display(just_t(0.5))

    #display(ForwardDiff.derivative(just_t,1.0))
    display(TaylorDiff.derivative(just_t, 1.0, 1)) #isnan error
    # Other AD backends that can be tried through DifferentiationInterface:
    #display(value_and_gradient(just_t, AutoForwardDiff(), 1.0))
    #display(value_and_gradient(just_t, AutoReverseDiff(), 1.0))
    #display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Reverse), 1.0))
    #display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Forward), 1.0))
    #display(value_and_gradient(just_t, AutoZygote(), 1.0))
end
ADTest()
Running the above, the error is
ERROR: LoadError: MethodError: no method matching isnan(::TaylorScalar{Float64, 2})
Closest candidates are:
isnan(::Missing)
@ Base missing.jl:101
isnan(::BigFloat)
@ Base mpfr.jl:982
isnan(::Complex)
@ Base complex.jl:151
...
Stacktrace:
[1] _any(f::typeof(isnan), itr::Tuple{TaylorScalar{Float64, 2}, TaylorScalar{Float64, 2}}, ::Colon)
@ Base ./reduce.jl:1220
[2] any(f::Function, itr::Tuple{TaylorScalar{Float64, 2}, TaylorScalar{Float64, 2}})
@ Base ./reduce.jl:1235
[3] get_concrete_tspan(prob::ODEProblem{…}, isadapt::Bool, kwargs::@Kwargs{…}, p::ModelingToolkit.MTKParameters{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1287
[4] get_concrete_problem(prob::ODEProblem{…}, isadapt::Bool; kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1169
[5] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1074
[6] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1003
[7] (::var"#different_time#1"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
@ Main ~/learning/ODETests/PLI/MWE3.jl:32
[8] (::var"#just_t#2"{var"#different_time#1"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
@ Main ~/learning/ODETests/PLI/MWE3.jl:37
[9] derivative
@ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:28 [inlined]
[10] derivative
@ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:18 [inlined]
[11] ADTest()
@ Main ~/learning/ODETests/PLI/MWE3.jl:44
[12] top-level scope
@ ~/learning/ODETests/PLI/MWE3.jl:53
[13] include(fname::String)
@ Base.MainInclude ./client.jl:489
[14] top-level scope
@ REPL[1]:1
If I go ahead and define `isnan` myself (you can uncomment the 4 lines near the top of the MWE), the error becomes:
ERROR: LoadError: Non-concrete element type inside of an `Array` detected.
Arrays with non-concrete element types, such as
`Array{Union{Float32,Float64}}`, are not supported by the
differential equation solvers. Anyways, this is bad for
performance so you don't want to be doing this!
If this was a mistake, promote the element types to be
all the same. If this was intentional, for example,
using Unitful.jl with different unit values, then use
an array type which has fast broadcast support for
heterogeneous values such as the ArrayPartition
from RecursiveArrayTools.jl. For example:
```julia
using RecursiveArrayTools
x = ArrayPartition([1.0,2.0],[1f0,2f0])
y = ArrayPartition([3.0,4.0],[3f0,4f0])
x .+ y # fast, stable, and usable as u0 into DiffEq!
```
Element type: Any
Some of the types have been truncated in the stacktrace for improved reading. To emit complete information
in the stack trace, evaluate TruncatedStacktraces.VERBOSE[] = true and re-run the code.
Stacktrace:
[1] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:592
[2] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1080
[3] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1003
[4] (::var"#different_time#3"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
@ Main ~/learning/ODETests/PLI/MWE3.jl:32
[5] (::var"#just_t#4"{var"#different_time#3"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
@ Main ~/learning/ODETests/PLI/MWE3.jl:37
[6] derivative
@ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:28 [inlined]
[7] derivative
@ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:18 [inlined]
[8] ADTest()
@ Main ~/learning/ODETests/PLI/MWE3.jl:44
[9] top-level scope
@ ~/learning/ODETests/PLI/MWE3.jl:53
[10] include(fname::String)
@ Base.MainInclude ./client.jl:489
[11] top-level scope
@ REPL[1]:1
in expression starting at /home/orebas/learning/ODETests/PLI/MWE3.jl:53
Some type information was truncated. Use show(err) to see complete types.
This was identified previously in #35; it is caused by type-system inconsistency issues. Unfortunately, I haven't figured out a good way to handle this yet...
OK, I now believe that `TaylorScalar` not being `<: Real` is a design error and needs to be fixed. I have started a fix at https://github.com/JuliaDiff/TaylorDiff.jl/tree/subtype-number; once it is done, this application should work.
Fixed in latest version 0.2.2
I'm still getting this error with the above MWE:
ERROR: LoadError: MethodError: no method matching TaylorScalar{Float64, 2}(::Tuple{Float64, ChainRulesCore.ZeroTangent})
Closest candidates are:
TaylorScalar{T, N}(::TaylorScalar{T, M}) where {T, N, M}
@ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:65
TaylorScalar{T, N}(::S, ::S) where {T, S<:Real, N}
@ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:58
TaylorScalar{T, N}(::S) where {T, S<:Real, N}
@ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:46
...
Stacktrace:
[1] sign(t::TaylorScalar{Float64, 2})
@ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/codegen.jl:20
[2] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::TaylorScalar{…}, dtmin::TaylorScalar{…}, dtmax::TaylorScalar{…}, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:120
[3] __solve(::ODEProblem{…}, ::Tsit5{…}; kwargs::@Kwargs{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:6
[4] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:612
[5] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1080
[6] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1003
[7] (::var"#different_time#5"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
@ Main ~/learning/ODETests/PLI/MWE3.jl:32
[8] (::var"#just_t#6"{var"#different_time#5"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
@ Main ~/learning/ODETests/PLI/MWE3.jl:37
[9] derivatives
@ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:66 [inlined]
[10] derivative
@ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:54 [inlined]
[11] derivative
@ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:35 [inlined]
[12] derivative
@ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:30 [inlined]
[13] ADTest()
@ Main ~/learning/ODETests/PLI/MWE3.jl:44
[14] top-level scope
@ ~/learning/ODETests/PLI/MWE3.jl:53
[15] include(fname::String)
@ Base.MainInclude ./client.jl:489
[16] top-level scope
@ REPL[5]:1
Oh, that's a problem with codegen. I will run your example and make it work tomorrow.
OK, so I fixed a minor problem related to converting the special tangent types from ChainRules. It should work now:
julia> ForwardDiff.derivative(just_t, 1.0)
14.778112197861631
julia> TaylorDiff.derivative(just_t, 1.0, 1)
14.77811219786163