
Recommended method for creating custom derivatives / gradients using Duals?

sdewaele opened this issue 6 years ago · 6 comments

The documentation currently describes how to add custom derivative definitions using DiffRules. However, this approach only seems to cover basic custom derivatives. For example, I don't see how it would support a black-box function (e.g. an external binary) that returns both the function value and the gradient, without having to call it twice. Therefore, I am interested in directly implementing a method for Dual.

Below is my attempt at such an example. Is this the recommended way to do this? I put it together after reading the source code referred to in this issue. It would be nice to cover the ℝⁿ → ℝⁿ case as well. Perhaps this example can be a starting point for documenting this way of creating custom derivatives.

using ForwardDiff

module MyModule
  using ForwardDiff
  using LinearAlgebra
  # ℝ → ℝ ——————————————————————————————————————————————
  "Original function"
  f0(x) = x^2

  """Returns f0(x) and its derivative
  In actual usage, this will be a function ForwardDiff cannot differentiate through,
  e.g. because it calls an external binary.
  """
  fg(x) = (v=f0(x),d=2x)
  
  "test function - calls `fg`"
  f(x) = fg(x).v

  "Custom derivative for `f` using `fg`"
  function f(d::ForwardDiff.Dual{T}) where T
    x = ForwardDiff.value(d)
    y = fg(x)
    ForwardDiff.Dual{T}(y.v,y.d*ForwardDiff.partials(d))
  end

  # ℝⁿ → ℝ ——————————————————————————————————————————————
  "Original function"
  f_vs0(x) = x[1] + x[2]^2

  "Returns f_vs0(x) and its gradient"
  fg_vs(x) = (v=f_vs0(x), d=[1.0, 2x[2]])

  "test function - calls `fg_vs`"
  f_vs(x) = fg_vs(x).v

  "Custom gradient for `f_vs` using `fg_vs`"
  function f_vs(d::Vector{D}) where D<:ForwardDiff.Dual
    x = ForwardDiff.value.(d)
    y = fg_vs(x)
    # gather the k-th partial of every input into one tuple, for each slot k
    b_in = zip(collect.(ForwardDiff.partials.(d))...)
    # chain rule: contract the gradient with each tuple of input partials
    b_arr = map(t->y.d⋅t, b_in)
    p = ForwardDiff.Partials((b_arr...,))
    D(y.v, p)
  end
end

## Testing

# ℝ → ℝ
x = 2.3
b = ForwardDiff.derivative(MyModule.f,x)
display(b == ForwardDiff.derivative(MyModule.f0,x))

# ℝⁿ → ℝ
v = [-0.3,3.4]
g = ForwardDiff.gradient(MyModule.f_vs,v)
display(all(g .== ForwardDiff.gradient(MyModule.f_vs0,v)))

## ℝ → ℝⁿ → ℝ
f = x->MyModule.f_vs([-1.2x,-0.5/x])
x2 = 0.7
b2 = ForwardDiff.derivative(f,x2)
f0 = x->MyModule.f_vs0([-1.2x,-0.5/x])
display(b2 == ForwardDiff.derivative(f0,x2))
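Following the same pattern, here is a sketch of what the ℝⁿ → ℝⁿ case might look like. All names below (`f_vv0`, `fj_vv`, `f_vv`) are illustrative and not part of the code above; `fj_vv` stands in for a black box that returns the value together with its Jacobian.

```julia
using ForwardDiff
using LinearAlgebra

# ℝⁿ → ℝⁿ sketch — same pattern, propagating a full Jacobian
f_vv0(x) = [x[1] + x[2]^2, x[1]*x[2]]

"Returns f_vv0(x) and its Jacobian; stands in for a black box"
fj_vv(x) = (v=f_vv0(x), J=[1.0 2x[2]; x[2] x[1]])

f_vv(x) = fj_vv(x).v

function f_vv(d::Vector{D}) where D<:ForwardDiff.Dual
    x = ForwardDiff.value.(d)
    y = fj_vv(x)
    # input partials as a matrix: row j holds the partials of input j (n × N)
    Pin = reduce(vcat, transpose.(collect.(ForwardDiff.partials.(d))))
    # chain rule: output partials are the Jacobian times the input partials (m × N)
    Pout = y.J * Pin
    [D(y.v[i], ForwardDiff.Partials((Pout[i,:]...,))) for i in eachindex(y.v)]
end

v = [-0.3, 3.4]
ForwardDiff.jacobian(f_vv, v) ≈ ForwardDiff.jacobian(f_vv0, v)
```

The same contraction handles nested duals as well, since the partials themselves can be Dual-valued; I have only checked the plain case here.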

sdewaele · Oct 11 '19 13:10

Feel free to take inspiration from https://github.com/JuliaDiff/ForwardDiff.jl/pull/165.

KristofferC · Oct 11 '19 13:10

Thanks!

In fact, I had already looked at parts of your code, e.g. _propagate_user_gradient!, while writing this. If there are plans to move that branch forward, that would be great. Regardless, it would probably be useful to have a recommendation available for hand-written custom gradients. For example, if I understand correctly, the current macro does not support computing the function value and gradient in a single pass. There may be other reasons to prefer a hand-written gradient as well.
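To make the single-pass point concrete, here is a minimal sketch (the `expensive` function and the call counter are illustrative): with a hand-written Dual method, one evaluation of the black box yields both the value and the derivative.

```julia
using ForwardDiff

# Illustrative stand-in for a black box returning value and derivative;
# the counter records how often it is evaluated.
calls = Ref(0)
function expensive(x)
    calls[] += 1
    (v = sin(x), d = cos(x))
end

g(x) = expensive(x).v

function g(d::ForwardDiff.Dual{T}) where T
    x = ForwardDiff.value(d)
    y = expensive(x)  # single call yields both value and derivative
    ForwardDiff.Dual{T}(y.v, y.d * ForwardDiff.partials(d))
end

ForwardDiff.derivative(g, 1.0)  # ≈ cos(1.0)
calls[]                         # 1: one evaluation sufficed
```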

sdewaele · Oct 11 '19 13:10

> the current macro does not support computing the function value and gradient in a single pass.

Yeah, the API in that PR should be changed to allow for this (define one function that returns both the function value and the gradient).

KristofferC · Oct 11 '19 13:10

@YingboMa wrote a pedagogical example of how to do this: https://gist.github.com/YingboMa/c22dcf8239a62e01b27ac679dfe5d4c5

We should get this into the docs, even if we do get a macro for it.

oxinabox · Mar 14 '21 21:03

Thanks @oxinabox for providing the link!

sdewaele · Mar 14 '21 22:03

The gist there is approximately the same as in my old PR (https://github.com/JuliaDiff/ForwardDiff.jl/pull/165/files#diff-5632cec511f57cd4be617f25c09846cde440b8fa54d35abcfd546952ab4f25b2R116-R131).

KristofferC · Mar 15 '21 08:03