ReverseDiff.jl
ReverseDiff for derivatives of custom types?
Hi,
I need to evaluate the gradient of a function that uses my custom type and operations. For example, if f(x) = x^3, I want the derivative given by the usual rules of calculus (i.e. df(x) = 3*x^2) to be evaluated for typeof(x) = foo. I have already defined the operations *(a::Number, b::foo), ^(a::foo, n::Integer), etc.
I know that Zygote can do this easily, but for my problem with Real inputs it is about an order of magnitude slower than ReverseDiff, so I would prefer to use ReverseDiff with a compiled tape.
Thanks,
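I don't think ReverseDiff is designed for that: it differentiates functions whose input is an array of real numbers (or a tuple of such arrays). Your best bet is to convert your type into an array of reals and rebuild the custom value from that array inside the function you differentiate.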
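A minimal sketch of that approach, assuming a hypothetical two-component type Foo (standing in for foo) with complex-like multiplication. ReverseDiff only ever sees a plain Vector{Float64}; the objective rebuilds a Foo from it, applies the custom operations, and returns a real number. The names Foo, to_foo and objective are illustrative, not part of ReverseDiff; GradientTape, compile and gradient! are ReverseDiff's own API.

```julia
using ReverseDiff

# Hypothetical stand-in for the poster's `foo`: a complex-like number with two
# real components. The type parameter lets the fields hold ReverseDiff's
# tracked reals while the tape is being recorded.
struct Foo{T<:Real}
    a::T
    b::T
end

# Custom operations, written generically so they work on tracked components.
Base.:*(x::Foo, y::Foo) = Foo(x.a * y.a - x.b * y.b, x.a * y.b + x.b * y.a)

function Base.:^(x::Foo, n::Integer)
    n >= 1 || throw(DomainError(n, "this sketch only handles n >= 1"))
    y = x
    for _ in 2:n
        y = y * x          # repeated multiplication using the method above
    end
    return y
end

# Conversion layer (hypothetical): unpack the plain vector into a Foo.
to_foo(v::AbstractVector) = Foo(v[1], v[2])

# Scalar objective: the real part of z^3, where z is built from v.
objective(v) = (to_foo(v)^3).a

v0 = [1.0, 2.0]

# Record the tape once on a plain Vector{Float64}, then compile it.
ctape = ReverseDiff.compile(ReverseDiff.GradientTape(objective, v0))

# Replay the compiled tape to fill a preallocated gradient buffer.
g = similar(v0)
ReverseDiff.gradient!(g, ctape, v0)
```

Here the objective is a^3 - 3ab^2, so at (a, b) = (1, 2) the gradient should be (3a^2 - 3b^2, -6ab) = (-9, -12). A compiled tape replays the operations recorded at v0, so it is only valid if the function has no branches that depend on the input values.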