TaylorDiff.jl

Can't differentiate through an ODE solver due to a missing `isnan` method and other type errors.

Open orebas opened this issue 1 year ago • 2 comments

TaylorDiff.jl seems to throw an error when I try to differentiate through a fairly simple ODE solve. There is also an error on the MTK side, but even after applying the workaround for it (see https://discourse.julialang.org/t/error-trying-to-forwarddiff-through-an-ode-solver/114339/6) I can't get TaylorDiff.jl to work.

MWE:

using ModelingToolkit, DifferentialEquations
using TaylorDiff, ForwardDiff
using DifferentiationInterface, Enzyme, Zygote, ReverseDiff
using SciMLSensitivity
#import Base.isnan
#function isnan(x::TaylorScalar{Float64, 2})
#	return false
#end

"""
    ADTest()

Minimal reproducer for differentiating through an ODE solve with TaylorDiff.

Builds the one-state linear ODE `D(x1) ~ a * x1` with ModelingToolkit, solves it
with `Tsit5`, and then differentiates `x1(new_t)` with respect to the final time
`new_t`. The derivative is taken by re-solving the problem on `[0, new_t]`.

With `x1(0) = 1` and `a = 2` the exact solution is `x1(t) = exp(2t)`, so the
expected derivative at `t = 1` is `2exp(2) ≈ 14.778`.
"""
function ADTest()
    @parameters a
    @variables t x1(t)
    D = Differential(t)
    states = [x1]
    parameters = [a]

    @named pre_model = ODESystem([D(x1) ~ a * x1], t, states, parameters)
    model = structural_simplify(pre_model)

    ic = Dict(x1 => 1.0)
    p_true = Dict(a => 2.0)

    problem = ODEProblem{true, SciMLBase.FullSpecialize}(model, ic, [0.0, 1.0], p_true)
    soln = ModelingToolkit.solve(problem, Tsit5(), abstol = 1e-12, reltol = 1e-12)
    display(soln(0.5, idxs = [x1]))

    # Re-solve on [0, new_t] and evaluate the *new* solution at new_t.
    # Returning the new solution (rather than interpolating the original `soln`)
    # is what makes the AD number type actually propagate through the solver.
    function different_time(new_ic, new_params, new_t)
        newprob = remake(problem, u0 = new_ic, tspan = [0.0, new_t], p = new_params)
        # Promote the initial state to the type of `new_t` (e.g. a TaylorScalar)
        # so that the state and time element types agree inside the solver.
        newprob = remake(newprob, u0 = typeof(new_t).(newprob.u0))
        new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
        return new_soln(new_t, idxs = [x1])
    end

    # Scalar function of the final time only, suitable for `derivative`.
    function just_t(new_t)
        return different_time(ic, p_true, new_t)[1]
    end

    display(different_time(ic, p_true, 2e-5))
    display(just_t(0.5))

    # Originally failed with `MethodError: no method matching isnan(::TaylorScalar{Float64, 2})`.
    display(TaylorDiff.derivative(just_t, 1.0, 1))

    # Other AD backends, for comparison:
    # display(ForwardDiff.derivative(just_t, 1.0))
    # display(value_and_gradient(just_t, AutoForwardDiff(), 1.0))
    # display(value_and_gradient(just_t, AutoReverseDiff(), 1.0))
    # display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Reverse), 1.0))
    # display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Forward), 1.0))
    # display(value_and_gradient(just_t, AutoZygote(), 1.0))
end

ADTest()

orebas avatar May 16 '24 20:05 orebas

Running the above produces this error:

ERROR: LoadError: MethodError: no method matching isnan(::TaylorScalar{Float64, 2})

Closest candidates are:
  isnan(::Missing)
   @ Base missing.jl:101
  isnan(::BigFloat)
   @ Base mpfr.jl:982
  isnan(::Complex)
   @ Base complex.jl:151
  ...

Stacktrace:
  [1] _any(f::typeof(isnan), itr::Tuple{TaylorScalar{Float64, 2}, TaylorScalar{Float64, 2}}, ::Colon)
    @ Base ./reduce.jl:1220
  [2] any(f::Function, itr::Tuple{TaylorScalar{Float64, 2}, TaylorScalar{Float64, 2}})
    @ Base ./reduce.jl:1235
  [3] get_concrete_tspan(prob::ODEProblem{…}, isadapt::Bool, kwargs::@Kwargs{…}, p::ModelingToolkit.MTKParameters{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1287
  [4] get_concrete_problem(prob::ODEProblem{…}, isadapt::Bool; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1169
  [5] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1074
  [6] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1003
  [7] (::var"#different_time#1"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:32
  [8] (::var"#just_t#2"{var"#different_time#1"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:37
  [9] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:28 [inlined]
 [10] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:18 [inlined]
 [11] ADTest()
    @ Main ~/learning/ODETests/PLI/MWE3.jl:44
 [12] top-level scope
    @ ~/learning/ODETests/PLI/MWE3.jl:53
 [13] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [14] top-level scope
    @ REPL[1]:1

orebas avatar May 16 '24 20:05 orebas

If I define `isnan` myself (uncomment the four lines near the top of the MWE), the error becomes:

ERROR: LoadError: Non-concrete element type inside of an `Array` detected.
Arrays with non-concrete element types, such as
`Array{Union{Float32,Float64}}`, are not supported by the
differential equation solvers. Anyways, this is bad for
performance so you don't want to be doing this!

If this was a mistake, promote the element types to be
all the same. If this was intentional, for example,
using Unitful.jl with different unit values, then use
an array type which has fast broadcast support for
heterogeneous values such as the ArrayPartition
from RecursiveArrayTools.jl. For example:

```julia
using RecursiveArrayTools
x = ArrayPartition([1.0,2.0],[1f0,2f0])
y = ArrayPartition([3.0,4.0],[3f0,4f0])
x .+ y # fast, stable, and usable as u0 into DiffEq!
```

Element type: Any

Some of the types have been truncated in the stacktrace for improved reading. To emit complete information in the stack trace, evaluate TruncatedStacktraces.VERBOSE[] = true and re-run the code.

Stacktrace:
  [1] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:592
  [2] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1080
  [3] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1003
  [4] (::var"#different_time#3"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:32
  [5] (::var"#just_t#4"{var"#different_time#3"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:37
  [6] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:28 [inlined]
  [7] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:18 [inlined]
  [8] ADTest()
    @ Main ~/learning/ODETests/PLI/MWE3.jl:44
  [9] top-level scope
    @ ~/learning/ODETests/PLI/MWE3.jl:53
 [10] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [11] top-level scope
    @ REPL[1]:1
in expression starting at /home/orebas/learning/ODETests/PLI/MWE3.jl:53
Some type information was truncated. Use show(err) to see complete types.

orebas avatar May 16 '24 20:05 orebas

This was identified previously in #35; it is caused by type-system inconsistencies. Unfortunately, I haven't yet found a good way to handle this...

tansongchen avatar May 17 '24 14:05 tansongchen

OK, I now believe that `TaylorScalar` not being a subtype of `Real` (not `<: Real`) is a design error that needs to be fixed. I started a fix at https://github.com/JuliaDiff/TaylorDiff.jl/tree/subtype-number; once it is done, this application should work for you.

tansongchen avatar May 22 '24 19:05 tansongchen

Fixed in the latest version, 0.2.2.

tansongchen avatar May 22 '24 19:05 tansongchen

I'm still getting this error with the above MWE:

ERROR: LoadError: MethodError: no method matching TaylorScalar{Float64, 2}(::Tuple{Float64, ChainRulesCore.ZeroTangent})

Closest candidates are:
  TaylorScalar{T, N}(::TaylorScalar{T, M}) where {T, N, M}
   @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:65
  TaylorScalar{T, N}(::S, ::S) where {T, S<:Real, N}
   @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:58
  TaylorScalar{T, N}(::S) where {T, S<:Real, N}
   @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:46
  ...

Stacktrace:
  [1] sign(t::TaylorScalar{Float64, 2})
    @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/codegen.jl:20
  [2] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::TaylorScalar{…}, dtmin::TaylorScalar{…}, dtmax::TaylorScalar{…}, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:120
  [3] __solve(::ODEProblem{…}, ::Tsit5{…}; kwargs::@Kwargs{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:6
  [4] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:612
  [5] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1080
  [6] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1003
  [7] (::var"#different_time#5"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:32
  [8] (::var"#just_t#6"{var"#different_time#5"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:37
  [9] derivatives
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:66 [inlined]
 [10] derivative
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:54 [inlined]
 [11] derivative
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:35 [inlined]
 [12] derivative
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:30 [inlined]
 [13] ADTest()
    @ Main ~/learning/ODETests/PLI/MWE3.jl:44
 [14] top-level scope
    @ ~/learning/ODETests/PLI/MWE3.jl:53
 [15] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [16] top-level scope
    @ REPL[5]:1

orebas avatar May 22 '24 21:05 orebas

Oh, that's a problem with codegen. I will run your example and make it work tomorrow.

tansongchen avatar May 23 '24 03:05 tansongchen

OK, I fixed a minor problem with converting ChainRules' special tangent types (such as `ZeroTangent`). It should work now:

julia> ForwardDiff.derivative(just_t, 1.0)
14.778112197861631

julia> TaylorDiff.derivative(just_t, 1.0, 1)
14.77811219786163

tansongchen avatar May 28 '24 20:05 tansongchen