jaxopt
Convenience API for QPs
Currently, the interface for QPs looks like this:
params_obj = (Q, c)
params_eq = (A, b)
params_ineq = (G, h)
osqp = jaxopt.OSQP()
osqp.run(init, params_obj, params_eq, params_ineq)
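For reference, jaxopt's OSQP documentation appears to describe the QP that these tuples encode as shown below: (Q, c) define the objective, (A, b) the equality constraints, and (G, h) the inequality constraints. In this form the gradient of the objective at x = 0 is c and its Hessian is Q. The AD-based idea below relies on exactly this property.

$$
\min_{x} \; \tfrac{1}{2} x^\top Q x + c^\top x
\quad \text{subject to} \quad A x = b, \;\; G x \le h
$$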
I propose that we also add the following convenience API:
def guaranteed_quadratic_function(params, params_obj):
[...]
osqp = jaxopt.OSQP(fun=guaranteed_quadratic_function)
osqp.run(init, params_obj, params_eq, params_ineq)
The idea is that we could compute Q and c internally for the user, using automatic differentiation (AD):
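# Inside the solver, with zeros a zero vector shaped like the primal variable:
grad = jax.grad(self.fun)  # gradient w.r.t. params (the first argument)
c = grad(zeros, params_obj)  # gradient at the origin is c
Q = jax.jacobian(grad)(zeros, params_obj)  # Jacobian of the gradient is the Hessian Q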
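A self-contained sketch of the extraction, using only JAX (no jaxopt). The objective `least_squares` and the helper `quadratic_coefficients` are hypothetical names for illustration. The recovered Q and c are exact only if the user's function really is quadratic in x. This is the "guaranteed" part of the proposed name.

```python
import jax
import jax.numpy as jnp


def least_squares(x, params_obj):
    # Illustrative objective written naturally rather than in (Q, c) form:
    # 0.5 * ||M x - y||^2 = 0.5 x^T (M^T M) x - (M^T y)^T x + const.
    M, y = params_obj
    r = M @ x - y
    return 0.5 * jnp.dot(r, r)


def quadratic_coefficients(fun, params_obj, n):
    """Recover (Q, c) from fun(x) = 0.5 x^T Q x + c^T x + const via AD.

    Only exact when fun is genuinely quadratic in x.
    """
    zeros = jnp.zeros(n)
    grad = jax.grad(fun)  # differentiates w.r.t. the first argument, x
    c = grad(zeros, params_obj)  # gradient at the origin equals c
    Q = jax.jacobian(grad)(zeros, params_obj)  # Jacobian of the gradient is Q
    return Q, c


M = jnp.array([[1.0, 2.0], [0.0, 3.0], [4.0, 1.0]])
y = jnp.array([1.0, -1.0, 2.0])
Q, c = quadratic_coefficients(least_squares, (M, y), n=2)
# Here Q should match M.T @ M and c should match -M.T @ y.
```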
Similarly, for solvers like OSQP that can work with matrix-vector products (matvecs), we can derive the matvec of Q automatically instead of materializing Q.
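One way to get that matvec, sketched here under the assumption that the objective is quadratic, is a Hessian-vector product: forward-mode `jax.jvp` applied to the reverse-mode gradient. The names `least_squares` and `make_q_matvec` are hypothetical; this is not jaxopt's internal implementation.

```python
import jax
import jax.numpy as jnp


def least_squares(x, params_obj):
    # Same illustrative quadratic: Q = M^T M, c = -M^T y.
    M, y = params_obj
    r = M @ x - y
    return 0.5 * jnp.dot(r, r)


def make_q_matvec(fun, params_obj):
    """Return a function v -> Q @ v that never materializes Q.

    Uses a Hessian-vector product (forward-mode jvp over the reverse-mode
    gradient). For a quadratic the Hessian is constant, so linearizing
    the gradient at the origin yields exactly Q.
    """
    grad = jax.grad(fun)

    def matvec(v):
        zeros = jnp.zeros_like(v)
        # jvp of x -> grad(x) at x = 0 in direction v gives H(0) @ v = Q @ v.
        _, q_v = jax.jvp(lambda x: grad(x, params_obj), (zeros,), (v,))
        return q_v

    return matvec


M = jnp.array([[1.0, 2.0], [0.0, 3.0], [4.0, 1.0]])
y = jnp.array([1.0, -1.0, 2.0])
matvec = make_q_matvec(least_squares, (M, y))
v = jnp.array([0.5, -2.0])
q_v = matvec(v)  # should equal (M.T @ M) @ v
```

Each call costs roughly one gradient evaluation, and memory stays linear in the problem size, which is what makes a matvec-based solver attractive for large Q.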
CC @Algue-Rythme
Or maybe this signature instead:
def guaranteed_quadratic_function(params, *params_obj):
[...]
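A rough sketch of how the variadic signature would interact with AD, assuming the solver unpacks `params_obj` when calling the user's function. The objective is illustrative and not from jaxopt.

```python
import jax
import jax.numpy as jnp


def least_squares(x, *params_obj):
    # Hypothetical variadic signature: objective parameters arrive as
    # separate positional arguments instead of one tuple.
    M, y = params_obj
    r = M @ x - y
    return 0.5 * jnp.dot(r, r)


params_obj = (jnp.array([[1.0, 2.0], [0.0, 3.0]]), jnp.array([1.0, -1.0]))
zeros = jnp.zeros(2)

# jax.grad defaults to argnums=0, so only x is differentiated even though
# the remaining arguments are unpacked from params_obj.
grad = jax.grad(least_squares)
c = grad(zeros, *params_obj)
Q = jax.jacobian(grad)(zeros, *params_obj)
```

Either signature works with AD, because differentiation is only with respect to the first argument. The difference is purely in how the solver forwards `params_obj` to the user's function.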
Good idea, I will take a look!
This is now done for OSQP! Keeping the issue open, as it also needs to be done for CvxpyQP, EqQP, and BoxCDQP, and the feature still needs to be documented.