
DESC - simsopt coupling. How to obtain derivatives w.r.t boundary harmonics?

Open • abaillod opened this issue 2 years ago • 5 comments

Hi!

I recently started working on coupling DESC with simsopt. The goal is to be able to use the DESC solver and its derivatives in single-stage optimization.

My understanding was that DESC has built-in derivatives using JAX. However, I did not find any examples in the documentation. My (unsuccessful) attempt so far is the following, where I try to get the derivative of iota w.r.t. the (m=0, n=0) R boundary mode:

from desc.equilibrium import Equilibrium
from desc.continuation import solve_continuation_automatic
from desc.geometry import FourierRZToroidalSurface
from desc.grid import LinearGrid
from jax import grad
import jax.numpy as jnp

# Define target function - mean value of iota profile
def fun(R00, equilibrium, grid):
    equilibrium.surface.set_coeffs(0,0,R00)
    eq = solve_continuation_automatic(equilibrium.copy())[-1]

    return jnp.mean(eq.compute('iota', grid)['iota'])

# Initialize DESC equilibrium
surf = FourierRZToroidalSurface.from_input_file('input.LandremanPaul2021_QH')
eq = Equilibrium(surface=surf)

# Define gradient of target function using JAX
grid_1d = LinearGrid(L=100)
f = lambda R00: fun(R00, eq, grid_1d)
dfdR = lambda R00: grad(f, argnums=0)(R00)

# Evaluate gradient
dfdR(1.0)

This does not work; JAX is unhappy:

---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[5], line 1
----> 1 dfdR(1.0)

Cell In[2], line 12, in <lambda>(R00)
     10 grid_1d = LinearGrid(L=100)
     11 f = lambda R00: fun(R00, eq, grid_1d)
---> 12 dfdR = lambda R00: grad(f, argnums=0)(R00)

    [... skipping hidden 10 frame]

Cell In[2], line 11, in <lambda>(R00)
      8 eq = Equilibrium(surface=surf)
     10 grid_1d = LinearGrid(L=100)
---> 11 f = lambda R00: fun(R00, eq, grid_1d)
     12 dfdR = lambda R00: grad(f, argnums=0)(R00)

Cell In[2], line 2, in fun(R00, equilibrium, grid)
      1 def fun(R00, equilibrium, grid):
----> 2     equilibrium.surface.set_coeffs(0,0,R00)
      3     eq = solve_continuation_automatic(equilibrium.copy())[-1]
      5     return jnp.mean(eq.compute('iota', grid)['iota'])

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/desc/geometry/surface.py:252, in FourierRZToroidalSurface.set_coeffs(self, m, n, R, Z)
    247 def set_coeffs(self, m, n=0, R=None, Z=None):
    248     """Set specific Fourier coefficients."""
    249     m, n, R, Z = (
    250         np.atleast_1d(m),
    251         np.atleast_1d(n),
--> 252         np.atleast_1d(R),
    253         np.atleast_1d(Z),
    254     )
    255     m, n, R, Z = np.broadcast_arrays(m, n, R, Z)
    256     for mm, nn, RR, ZZ in zip(m, n, R, Z):

File <__array_function__ internals>:200, in atleast_1d(*args, **kwargs)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/numpy/core/shape_base.py:65, in atleast_1d(*arys)
     63 res = []
     64 for ary in arys:
---> 65     ary = asanyarray(ary)
     66     if ary.ndim == 0:
     67         result = ary.reshape(1)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/core.py:598, in Tracer.__array__(self, *args, **kw)
    597 def __array__(self, *args, **kw):
--> 598   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(1.0, dtype=float64, weak_type=True)>with<JVPTrace(level=2/0)> with
  primal = 1.0
  tangent = Traced<ShapedArray(float64[], weak_type=True)>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float64[], weak_type=True), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Is this the right approach to get such derivatives? If not, can you point me towards some example where this is done? Thank you!
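The traceback shows the mechanism: `jax.grad` replaces `R00` with a tracer, and `FourierRZToroidalSurface.set_coeffs` passes it to plain NumPy (`np.atleast_1d`), which calls the tracer's `__array__` method and triggers `TracerArrayConversionError`. A minimal sketch reproducing this with toy functions (`f_numpy` and `f_jax` are hypothetical, not DESC code) is below. Swapping in `jax.numpy` at one call site is unlikely to be enough on its own, since the equilibrium solve itself is an iterative procedure; the replies below describe the approach DESC supports.

```python
import numpy as np
import jax
import jax.numpy as jnp


def f_numpy(r00):
    # np.atleast_1d converts its input with __array__, which a JAX tracer forbids
    return np.sum(np.atleast_1d(r00) ** 2)


def f_jax(r00):
    # jnp.atleast_1d keeps the value as a traceable JAX array
    return jnp.sum(jnp.atleast_1d(r00) ** 2)


try:
    jax.grad(f_numpy)(1.0)
    numpy_version_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_version_failed = True

# d/dr (r^2) evaluated at r = 1
dfdr = jax.grad(f_jax)(1.0)
print(numpy_version_failed, dfdr)
```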

abaillod, Feb 07 '24 15:02

Mea culpa, I just found desc.objectives, which seems to do what I want!

abaillod, Feb 07 '24 16:02

desc.objectives won't directly give you what you want, because our functions do not take the boundary modes as arguments; they take the coefficients of the basis describing the whole solution (and JAX gives derivatives w.r.t. the arguments of a function). What you want, I assume, is the derivative of iota with respect to the boundary modes at constant force error. With VMEC, one would normally find this by repeatedly calling the equilibrium solver with slightly different boundary shapes.
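To make the "repeated solves" approach concrete, here is a toy sketch under stated assumptions. A hypothetical two-variable residual `residual(x, c)` stands in for force balance, `x` stands in for the solution coefficients, `c` for the boundary coefficients, and `g` for an output such as mean iota. None of these are DESC or VMEC functions. Each gradient component costs two full equilibrium solves.

```python
import numpy as np


def residual(x, c):
    """Hypothetical toy 'force balance' F(x, c); F = 0 defines the equilibrium x(c)."""
    return np.array([
        2.0 * x[0] + x[1] + 0.1 * x[0] ** 3 - c[0],
        x[0] + 3.0 * x[1] + 0.1 * x[1] ** 3 - c[1],
    ])


def jac_x(x):
    """dF/dx for the toy residual above."""
    return np.array([
        [2.0 + 0.3 * x[0] ** 2, 1.0],
        [1.0, 3.0 + 0.3 * x[1] ** 2],
    ])


def solve(c, tol=1e-12, maxiter=50):
    """Newton solve of F(x, c) = 0 from a cold start (one 'equilibrium solve')."""
    x = np.zeros(2)
    for _ in range(maxiter):
        dx = np.linalg.solve(jac_x(x), -residual(x, c))
        x = x + dx
        if np.linalg.norm(dx) < tol:
            break
    return x


def g(x):
    """Scalar output of the equilibrium, standing in for e.g. mean(iota)."""
    return np.mean(x)


def fd_gradient(c, h=1e-6):
    """Central differences: two full equilibrium solves per input coefficient."""
    grad = np.zeros_like(c)
    for i in range(len(c)):
        e = np.zeros_like(c)
        e[i] = h
        grad[i] = (g(solve(c + e)) - g(solve(c - e))) / (2.0 * h)
    return grad


c = np.array([1.0, 2.0])  # stand-in for boundary coefficients
print(fd_gradient(c))
```

With many boundary modes, this costs two solves per mode. That cost is what motivates the single-Jacobian approach discussed next in the thread.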

DESC can do this with one equilibrium solve (or one Jacobian evaluation of the force balance objective). This is what is done internally in the optimal_perturb function in perturbations.py. Roughly speaking, it tries to find a change in the inputs $c$ (the boundary coefficients, profiles, and Psi) that improves the objective $g$ (e.g. iota, QS, etc.) while approximately maintaining force balance (see this paper for more detail).
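The reason a single Jacobian suffices is the implicit function theorem. Along the solution branch $F(x(c), c) = 0$, so $\partial_x F \, dx/dc + \partial_c F = 0$. This gives $dg/dc = \partial_c g - (\partial_c F)^T \lambda$ with $(\partial_x F)^T \lambda = \partial_x g$, i.e. one linear (adjoint) solve at the converged state yields every component. The JAX sketch below applies this to a hypothetical toy residual. It illustrates the general idea only; it is not DESC's optimal_perturb code, and `residual`, `solve`, and `implicit_grad` are made-up names.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for the linear solves


def residual(x, c):
    """Hypothetical toy 'force balance' F(x, c); F = 0 defines x(c)."""
    return jnp.stack([
        2.0 * x[0] + x[1] + 0.1 * x[0] ** 3 - c[0],
        x[0] + 3.0 * x[1] + 0.1 * x[1] ** 3 - c[1],
    ])


def g(x, c):
    """Scalar objective; here it depends on c only through x(c)."""
    return jnp.mean(x)


def solve(c, x0, iters=20):
    """Plain Newton iterations for F(x, c) = 0 (never differentiated through)."""
    x = x0
    for _ in range(iters):
        J = jax.jacfwd(residual, argnums=0)(x, c)
        x = x - jnp.linalg.solve(J, residual(x, c))
    return x


def implicit_grad(x, c):
    """dg/dc at an equilibrium via the implicit function theorem (one adjoint solve)."""
    F_x = jax.jacfwd(residual, argnums=0)(x, c)  # dF/dx at the solution
    F_c = jax.jacfwd(residual, argnums=1)(x, c)  # dF/dc
    g_x = jax.grad(g, argnums=0)(x, c)
    g_c = jax.grad(g, argnums=1)(x, c)           # zeros here: g has no explicit c
    lam = jnp.linalg.solve(F_x.T, g_x)           # adjoint: F_x^T lam = dg/dx
    return g_c - F_c.T @ lam


c = jnp.array([1.0, 2.0])
x = solve(c, jnp.zeros(2))
print(implicit_grad(x, c))
```

Only the converged state is used, so the cost is one solve plus one Jacobian, regardless of how many inputs $c$ there are.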

We'll put together a tutorial script showing how to do this and post it here later today.

dpanici, Feb 07 '24 16:02

Awesome, thank you

abaillod, Feb 07 '24 17:02

I think something like this is what you want:

import jax
import desc.examples
from desc.objectives import ObjectiveFunction, RotationalTransform, ForceBalance
from desc.optimize import ProximalProjection

eq = desc.examples.get("DSHAPE_CURRENT")

# objective: mean rotational transform (target 1.2); constraint: force balance
objective = ObjectiveFunction(RotationalTransform(eq, target=1.2, loss_function="mean"))
constraint = ObjectiveFunction(ForceBalance(eq))

prox = ProximalProjection(objective, constraint, eq)
prox.build()

x = prox.x(eq)  # vector of all the boundary/profile dofs
jac = prox.jac_scaled(x)  # d(iota)/d(boundary + profile dofs)

# unpack the big matrix into dict of individual components
dg_dparams = jax.vmap(prox.unpack_state, (0, None))(jac, False)[0]
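The last line maps an unpacking function over the rows of the Jacobian: `in_axes=(0, None)` maps the first argument along axis 0 and passes the second argument unchanged to every call, and dict outputs come back with a stacked leading axis. A self-contained sketch of that pattern, with a hypothetical `unpack` function and made-up dof names and sizes (not DESC's layout), is:

```python
import jax
import jax.numpy as jnp

# Hypothetical layout of a flat dof vector (names and sizes are made up)
SIZES = {"Rb": 3, "Zb": 2, "profile": 2}


def unpack(row, scale):
    """Split one flat row into named pieces; `scale` is an unmapped extra argument."""
    out, start = {}, 0
    for name, n in SIZES.items():
        out[name] = scale * row[start:start + n]
        start += n
    return out


jac = jnp.arange(14.0).reshape(2, 7)  # 2 objective rows x 7 dofs

# in_axes=(0, None): map over rows of jac, pass `scale` unchanged to every call
parts = jax.vmap(unpack, in_axes=(0, None))(jac, 1.0)
print({k: v.shape for k, v in parts.items()})
```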

The ProximalProjection object takes care of solving the equilibrium and combining all the Jacobians to give you the derivative you want w.r.t. the boundary terms. Be aware, however, that it maintains state: if you evaluate it at a different x, it will update the equilibrium to that new state. It uses a perturb/re-solve method, so it only works well if you call it with a sequence of xs that change slowly. That is fine for gradient-based optimization but might cause issues if you're doing global optimization; let me know and we can probably modify it to deal with that.
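Why slowly changing inputs matter can be seen in a toy version of a stateful perturb/re-solve scheme. The sketch below assumes a hypothetical elementwise residual and a made-up `WarmStartedEquilibrium` class; it is not how ProximalProjection is implemented. It stores the last solved state, predicts the new state with a first-order perturbation, then corrects with Newton iterations. A small step needs very few corrections, while a large jump needs many more.

```python
import numpy as np


def residual(x, c):
    """Hypothetical elementwise toy equilibrium condition F(x, c) = x + 0.1 x^3 - c = 0."""
    return x + 0.1 * x ** 3 - c


def jac_x(x):
    """dF/dx for the toy residual (diagonal)."""
    return np.diag(1.0 + 0.3 * x ** 2)


class WarmStartedEquilibrium:
    """Hypothetical stateful wrapper that remembers the last solved (c, x) pair."""

    def __init__(self, c0):
        self.c = np.asarray(c0, dtype=float)
        self.x, _ = self._newton(np.zeros_like(self.c), self.c)

    @staticmethod
    def _newton(x, c, tol=1e-12, maxiter=50):
        """Newton iterations; returns the solution and the number of iterations used."""
        for k in range(maxiter):
            dx = np.linalg.solve(jac_x(x), -residual(x, c))
            x = x + dx
            if np.linalg.norm(dx) < tol:
                return x, k + 1
        return x, maxiter

    def update(self, c_new):
        """Perturb from the stored state, then re-solve; returns Newton iterations used."""
        c_new = np.asarray(c_new, dtype=float)
        # first-order perturbation: dx = -(dF/dx)^{-1} (dF/dc) dc, with dF/dc = -I here
        x_pred = self.x + np.linalg.solve(jac_x(self.x), c_new - self.c)
        # re-solve: Newton corrections starting from the prediction
        self.x, iters = self._newton(x_pred, c_new)
        self.c = c_new
        return iters


eq = WarmStartedEquilibrium([1.0, 2.0])
small = eq.update([1.001, 2.001])  # small step: prediction is already close
large = eq.update([10.0, 20.0])    # large jump: many more corrections needed
print(small, large)
```

In this toy, the large jump still converges because the residual is monotone. In a real equilibrium solve, a poor prediction can also fail to converge, which is the concern raised above for global methods.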

f0uriest, Feb 07 '24 21:02

@abaillod did that work? Is there anything else we can help with?

f0uriest, Mar 06 '24 21:03

Closing as presumably resolved; feel free to re-open if there are still questions.

dpanici, Aug 20 '24 19:08