DESC-simsopt coupling: how to obtain derivatives w.r.t. boundary harmonics?
Hi!
I recently started working on coupling DESC with simsopt. The goal is to be able to use the DESC solver and its derivatives in single-stage optimization.
My understanding was that DESC has built-in derivatives via JAX, but I did not find any examples in the documentation. My (unsuccessful) attempt so far is below; I try to get the derivative of iota with respect to the (m=0, n=0) R boundary mode:
```python
from desc.equilibrium import Equilibrium
from desc.continuation import solve_continuation_automatic
from desc.geometry import FourierRZToroidalSurface
from desc.grid import LinearGrid
from jax import grad
import jax.numpy as jnp

# Define target function - mean value of iota profile
def fun(R00, equilibrium, grid):
    equilibrium.surface.set_coeffs(0,0,R00)
    eq = solve_continuation_automatic(equilibrium.copy())[-1]
    return jnp.mean(eq.compute('iota', grid)['iota'])

# Initialize DESC equilibrium
surf = FourierRZToroidalSurface.from_input_file('input.LandremanPaul2021_QH')
eq = Equilibrium(surface=surf)

# Define gradient of target function using JAX
grid_1d = LinearGrid(L=100)
f = lambda R00: fun(R00, eq, grid_1d)
dfdR = lambda R00: grad(f, argnums=0)(R00)

# Evaluate gradient
dfdR(1.0)
```
This does not work; JAX raises an error:
```
---------------------------------------------------------------------------
TracerArrayConversionError Traceback (most recent call last)
Cell In[5], line 1
----> 1 dfdR(1.0)
Cell In[2], line 12, in <lambda>(R00)
10 grid_1d = LinearGrid(L=100)
11 f = lambda R00: fun(R00, eq, grid_1d)
---> 12 dfdR = lambda R00: grad(f, argnums=0)(R00)
[... skipping hidden 10 frame]
Cell In[2], line 11, in <lambda>(R00)
8 eq = Equilibrium(surface=surf)
10 grid_1d = LinearGrid(L=100)
---> 11 f = lambda R00: fun(R00, eq, grid_1d)
12 dfdR = lambda R00: grad(f, argnums=0)(R00)
Cell In[2], line 2, in fun(R00, equilibrium, grid)
1 def fun(R00, equilibrium, grid):
----> 2 equilibrium.surface.set_coeffs(0,0,R00)
3 eq = solve_continuation_automatic(equilibrium.copy())[-1]
5 return jnp.mean(eq.compute('iota', grid)['iota'])
File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/desc/geometry/surface.py:252, in FourierRZToroidalSurface.set_coeffs(self, m, n, R, Z)
247 def set_coeffs(self, m, n=0, R=None, Z=None):
248 """Set specific Fourier coefficients."""
249 m, n, R, Z = (
250 np.atleast_1d(m),
251 np.atleast_1d(n),
--> 252 np.atleast_1d(R),
253 np.atleast_1d(Z),
254 )
255 m, n, R, Z = np.broadcast_arrays(m, n, R, Z)
256 for mm, nn, RR, ZZ in zip(m, n, R, Z):
File <__array_function__ internals>:200, in atleast_1d(*args, **kwargs)
File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/numpy/core/shape_base.py:65, in atleast_1d(*arys)
63 res = []
64 for ary in arys:
---> 65 ary = asanyarray(ary)
66 if ary.ndim == 0:
67 result = ary.reshape(1)
File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/core.py:598, in Tracer.__array__(self, *args, **kw)
597 def __array__(self, *args, **kw):
--> 598 raise TracerArrayConversionError(self)
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(1.0, dtype=float64, weak_type=True)>with<JVPTrace(level=2/0)> with
primal = 1.0
tangent = Traced<ShapedArray(float64[], weak_type=True)>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float64[], weak_type=True), None)
recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
```
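The last frames show where the error comes from: `set_coeffs` passes `R00` through `np.atleast_1d`, and NumPy cannot turn the JAX tracer that `grad` substitutes for the input into a concrete array. A minimal sketch reproducing the same error in isolation (the function names are made up for illustration; this is not DESC code):

```python
import numpy as np
import jax
import jax.numpy as jnp

def via_numpy(r):
    # Mirrors set_coeffs: NumPy tries to convert the JAX tracer to an ndarray.
    return np.atleast_1d(r).sum()

def via_jax_numpy(r):
    # Same computation with jax.numpy keeps the value traceable.
    return jnp.atleast_1d(r).sum()

try:
    jax.grad(via_numpy)(1.0)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True

print(raised)
print(jax.grad(via_jax_numpy)(1.0))
```

Swapping NumPy for `jax.numpy` fixes this toy, but it likely would not by itself make `fun` differentiable, since the value would still have to flow through the whole iterative equilibrium solve.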
Is this the right approach to get such derivatives? If not, can you point me towards some example where this is done? Thank you!
Mea culpa, I just found `desc.objectives`, which seems to do what I want!
`desc.objectives` won't directly give you what you want. Our functions do not take the boundary modes as arguments, but rather the coefficients of the basis describing the whole solution, and JAX only gives derivatives with respect to a function's arguments. What I assume you want is the derivative of iota with respect to the boundary modes at constant force error, which with VMEC one would normally find by repeatedly solving the equilibrium with slightly different boundary shapes.
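Written out (a sketch with notation introduced here for illustration, not DESC's internal variable names): let $c$ be the inputs (boundary modes, profiles, Psi), $x$ the coefficients describing the whole equilibrium solution, and $f(x, c) = 0$ the force balance condition that defines $x(c)$. For an objective $g(x, c)$ such as iota, the desired derivative and its VMEC-style finite-difference approximation are

$$
\frac{dg}{dc_i} = \frac{\partial g}{\partial x}\,\frac{dx}{dc_i} + \frac{\partial g}{\partial c_i}
\approx \frac{g\big(x(c + h\,e_i),\, c + h\,e_i\big) - g\big(x(c),\, c\big)}{h},
$$

where $e_i$ is the $i$-th unit vector and $h$ a small step. Each $x(c + h\,e_i)$ needs its own equilibrium solve, so this approach costs one solve per input.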
DESC can do this with a single equilibrium solve (or, equivalently, a single Jacobian evaluation of the force balance objective). This is what is done internally in the `optimal_perturb` function in `perturbations.py`, which more or less attempts to find a change in the inputs $c$ (the boundary coefficients, profiles, and Psi) that improves the objective $g$ (e.g. iota, QS, etc.) while approximately maintaining force balance (see this paper for more detail).
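Differentiating $f(x(c), c) = 0$ with respect to $c$ gives the implicit-function-theorem result $\frac{dx}{dc} = -\left(\frac{\partial f}{\partial x}\right)^{-1}\frac{\partial f}{\partial c}$, so one converged solve plus the Jacobians of $f$ give $dg/dc = \partial g/\partial c + (\partial g/\partial x)\,dx/dc$ for all inputs at once. Here $x$ is the equilibrium state, $c$ the inputs, $f$ the force residual and $g$ the objective. The sketch below illustrates that idea in plain JAX on a hypothetical toy problem: `residual`, `solve`, and `objective` are made-up stand-ins for force balance, the equilibrium solver, and e.g. mean iota, and this is not how `optimal_perturb` is actually implemented. It checks the result against a finite difference that re-solves once per input.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for the FD check

# Hypothetical toy "force balance": f(x, c) = 0 defines the state x(c).
def residual(x, c):
    return jnp.array([
        x[0] + 0.1 * x[1] ** 2 - c[0],
        x[1] + 0.1 * jnp.sin(x[0]) - c[1],
    ])

def solve(c, iters=20):
    """Newton's method for f(x, c) = 0 (stand-in for an equilibrium solve)."""
    x = jnp.zeros(2)
    for _ in range(iters):
        J = jax.jacfwd(residual, argnums=0)(x, c)
        x = x - jnp.linalg.solve(J, residual(x, c))
    return x

# Hypothetical scalar objective g(x, c), standing in for e.g. mean iota.
def objective(x, c):
    return x[0] * x[1] + 0.5 * c[0]

def implicit_grad(c):
    """dg/dc at zero residual, using one solve and one set of Jacobians."""
    x = solve(c)
    f_x = jax.jacfwd(residual, argnums=0)(x, c)  # df/dx
    f_c = jax.jacfwd(residual, argnums=1)(x, c)  # df/dc
    g_x = jax.grad(objective, argnums=0)(x, c)   # dg/dx
    g_c = jax.grad(objective, argnums=1)(x, c)   # explicit dg/dc
    dx_dc = -jnp.linalg.solve(f_x, f_c)          # implicit function theorem
    return g_c + g_x @ dx_dc

def finite_difference_grad(c, h=1e-6):
    """VMEC-style check: one extra solve per input."""
    g0 = objective(solve(c), c)
    out = []
    for i in range(c.size):
        cp = c.at[i].add(h)
        out.append((objective(solve(cp), cp) - g0) / h)
    return jnp.array(out)

c = jnp.array([1.0, 0.5])
print(implicit_grad(c))
print(finite_difference_grad(c))
```

The two printed vectors should agree to roughly the size of the finite-difference step; the implicit version needs no extra solves as the number of inputs grows.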
We'll put together a tutorial script to show how to do this and post it here later today
Awesome, thank you
I think something like this is what you want:
```python
import jax
import desc.examples
from desc.objectives import ObjectiveFunction, RotationalTransform, ForceBalance
from desc.optimize import ProximalProjection

eq = desc.examples.get("DSHAPE_CURRENT")
objective = ObjectiveFunction(RotationalTransform(eq, target=1.2, loss_function="mean"))
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
prox.build()
x = prox.x(eq)  # vector of all the boundary/profile dofs
jac = prox.jac_scaled(x)  # d(iota)/d(boundary + profile dofs)

# unpack the big matrix into a dict of individual components
jax.vmap(prox.unpack_state, (0, None))(jac, False)[0]
```
`ProximalProjection` basically takes care of solving the equilibrium and combining all the Jacobians to give you the derivative you want with respect to the boundary terms. Be aware, however, that it maintains state: if you evaluate it at a different `x`, it updates the equilibrium to that new state. It uses a perturb/re-solve method, so it only really works if you call it with a sequence of `x`s that don't change too quickly. That is fine for gradient-based optimization but might cause issues for global optimization; let me know if that's your use case and we can probably modify it to handle it.
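To make the statefulness concrete, here is a hypothetical toy in plain JAX (`ToyProximal`, `residual`, and `objective` are made-up names, and this is not `ProximalProjection`'s actual implementation). The wrapper caches the last solution, and each gradient request at new inputs first takes a few Newton steps warm-started from that cached state, so every call can mutate the object and relies on successive inputs being close.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def residual(x, c):
    # Hypothetical toy force balance f(x, c) = 0 (not DESC's ForceBalance).
    return jnp.array([
        x[0] + 0.1 * x[1] ** 2 - c[0],
        x[1] + 0.1 * jnp.sin(x[0]) - c[1],
    ])

def objective(x, c):
    # Hypothetical scalar objective, standing in for e.g. mean iota.
    return x[0] * x[1] + 0.5 * c[0]

class ToyProximal:
    """Hypothetical stateful wrapper: caches the last equilibrium x and
    warm-starts a few Newton steps from it when asked about new inputs c."""

    def __init__(self, c0, newton_steps=3):
        self.c = c0
        self.x = jnp.zeros(2)
        self.newton_steps = newton_steps
        self._update(c0, steps=50)  # full solve for the initial state

    def _update(self, c, steps):
        x = self.x  # warm start from the cached state
        for _ in range(steps):
            J = jax.jacfwd(residual, argnums=0)(x, c)
            x = x - jnp.linalg.solve(J, residual(x, c))
        self.x, self.c = x, c  # the wrapper's state now reflects c

    def grad(self, c):
        # Re-solve only if c changed; note that this mutates self.x and self.c.
        if not jnp.array_equal(c, self.c):
            self._update(c, self.newton_steps)
        x = self.x
        f_x = jax.jacfwd(residual, argnums=0)(x, c)
        f_c = jax.jacfwd(residual, argnums=1)(x, c)
        dx_dc = -jnp.linalg.solve(f_x, f_c)  # implicit function theorem
        g_x = jax.grad(objective, argnums=0)(x, c)
        g_c = jax.grad(objective, argnums=1)(x, c)
        return g_c + g_x @ dx_dc

prox = ToyProximal(jnp.array([1.0, 0.5]))
for step in range(3):
    c_new = prox.c + 0.05  # small changes: a few warm-started steps suffice
    g = prox.grad(c_new)
    print(step, jnp.linalg.norm(residual(prox.x, prox.c)), g)
```

With large jumps in `c`, a small fixed number of warm-started steps may not reach force balance, which is the toy analogue of the caveat about global optimization.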
@abaillod did that work? Is there anything else we can help with?
Closing as assumed resolved; feel free to reopen if there are still questions.