jaxopt's L-BFGS-B with custom gradient not matching scipy's implementation
Context:
I am migrating code (causalnex's dynotears) from a numpy/scipy implementation to a jax implementation. This essentially involves moving from scipy's L-BFGS-B to jaxopt's implementation so that I can jit the function and run it faster.
Apart from the _func(..) to minimize, the code defines a custom _grad(..) function for the optimization. I converted both _func() and _grad() to their jax counterparts, and I am using jaxopt.LBFGSB with the custom gradient as shown below.
Original numpy/scipy implementation
import warnings

import numpy as np
import scipy.optimize as sopt

# initialise matrix, weights and constraints
wa_est = np.zeros(2 * (p_orders + 1) * d_vars**2)
wa_new = np.zeros(2 * (p_orders + 1) * d_vars**2)
rho, alpha, h_value, h_new = 1.0, 0.0, np.inf, np.inf

for n_iter in range(max_iter):
    while (rho < 1e20) and (h_new > 0.25 * h_value or h_new == np.inf):
        wa_new = sopt.minimize(
            _func,
            wa_est,
            method="L-BFGS-B",
            jac=_grad,
            bounds=bnds,
        ).x
        h_new = _h(wa_new, d_vars, p_orders)
        if h_new > 0.25 * h_value:
            rho *= 10
    wa_est = wa_new
    h_value = h_new
    alpha += rho * h_value
    if h_value <= h_tol:
        break
    if h_value > h_tol and n_iter == max_iter - 1:
        warnings.warn("Failed to converge. Consider increasing max_iter.")
My current jaxopt implementation
import warnings

import jax.numpy as jnp
import numpy as np
from jaxopt import LBFGSB

# bnds is a list of (lower, upper) tuples, where upper might have None values.
# Make it compatible with what jaxopt.LBFGSB expects: a (lowers, uppers) pair of arrays.
np_bnds = np.array(bnds)
lowers = jnp.array(np_bnds[:, 0].astype(float))
cleaned_uppers = np.where(np_bnds[:, 1] == None, jnp.inf, np_bnds[:, 1])
uppers = jnp.array(cleaned_uppers.astype(float))
jnp_lbfgs_bounds = (lowers, uppers)

lbfgsb_solver = LBFGSB(fun=_func_jax, value_and_grad=True)

# wa_est, rho, alpha, h_value, h_new are initialised as in the scipy version above
for n_iter in range(max_iter):
    while (rho < 1e20) and (h_new > 0.25 * h_value or h_new == jnp.inf):
        wa_new = lbfgsb_solver.run(
            wa_est,
            bounds=jnp_lbfgs_bounds,
        ).params
        h_new = _h_jax(wa_new, d_vars, p_orders)
        if h_new > 0.25 * h_value:
            rho *= 10
    wa_est = wa_new
    h_value = h_new
    alpha += rho * h_value
    if h_value <= h_tol:
        break
    if h_value > h_tol and n_iter == max_iter - 1:
        warnings.warn("Failed to converge. Consider increasing max_iter.")
I have ensured that _func_jax returns (loss, _grad_jax(params)), whereas _func() returns just the scalar. I am not expecting the scipy and jaxopt implementations to agree exactly, since I understand there will be numerical differences even with seeds set, but the mismatch I am seeing is large.
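For clarity, the pattern I am relying on for value_and_grad=True looks like this self-contained toy (a sketch; the quadratic objective and the toy_* names are placeholders, not my actual functions):

import jax.numpy as jnp
from jaxopt import LBFGSB

def toy_loss(params):
    return jnp.sum((params - 2.0) ** 2)

def toy_grad(params):
    # hand-written gradient, standing in for my _grad_jax
    return 2.0 * (params - 2.0)

def toy_func(params):
    # with value_and_grad=True, jaxopt treats the return as (value, gradient)
    # and does not differentiate the function itself
    return toy_loss(params), toy_grad(params)

solver = LBFGSB(fun=toy_func, value_and_grad=True)
lower, upper = jnp.zeros(3), jnp.full(3, jnp.inf)
result = solver.run(jnp.zeros(3), bounds=(lower, upper))
print(result.params)  # expected to be close to [2., 2., 2.]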
I do get some warnings during my run like:
WARNING: jaxopt.ZoomLineSearch: No interval satisfying curvature condition.Consider increasing maximal possible stepsize of the linesearch.
WARNING: jaxopt.ZoomLineSearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
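In case it is relevant, these are the solver knobs I would consider adjusting in response to the line search warnings (a sketch; I am assuming jaxopt.LBFGSB 0.8.3 accepts maxls, max_stepsize, history_size, maxiter, and tol as constructor arguments, so the exact names may need checking against the installed version):

lbfgsb_solver = LBFGSB(
    fun=_func_jax,
    value_and_grad=True,
    maxiter=500,        # allow more iterations; scipy's default maxiter is 15000
    tol=1e-5,           # tighter tolerance, closer to scipy's gtol=1e-5
    history_size=10,    # scipy's maxcor default is 10
    maxls=30,           # assumption: maximum number of line search steps, if supported
    max_stepsize=10.0,  # assumption: the ZoomLineSearch warning suggests raising the maximal stepsize
)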
It would really help to understand what is causing these differences (and whether they are expected or not).
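For reference, this is the kind of sanity check I plan to run to rule out an error in my converted gradient (a sketch; check_gradient is a hypothetical helper, and the lambda in the usage comment just extracts the scalar loss from _func_jax):

import jax
import jax.numpy as jnp

def check_gradient(value_fn, grad_fn, params, atol=1e-5):
    # Compare a hand-written gradient against jax autodiff of the scalar loss.
    autodiff_grad = jax.grad(value_fn)(params)
    manual_grad = grad_fn(params)
    return jnp.max(jnp.abs(autodiff_grad - manual_grad)) <= atol

# usage with my functions:
# check_gradient(lambda p: _func_jax(p)[0], _grad_jax, wa_est)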
Versions:
jax: 0.4.31
jaxopt: 0.8.3
numpy: 1.23.5
scipy: 1.13.1
Here are the values from my scipy and jaxopt runs, in that order.
Final objective: 4.84871 (scipy) vs 1.5213378e+20 (jaxopt), which seems concerning.
Optimized value of wa_new (scipy/numpy)
array([0.00000000e+00, 1.58362135e+00, 0.00000000e+00, 8.64218226e-01,
0.00000000e+00, 5.30177996e-06, 0.00000000e+00, 7.23155393e-06,
1.08522093e-04, 3.21722607e-01, 0.00000000e+00, 1.58837897e+00,
0.00000000e+00, 2.17025397e-01, 5.02410566e-01, 9.24620810e-06,
2.14839298e-01, 4.42866934e-05, 0.00000000e+00, 0.00000000e+00,
4.66671845e-06, 9.10319126e-05, 2.85940340e-05, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 8.72885996e-04,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.54268367e-02,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
3.84189744e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
4.18412632e-05, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 1.66891599e-01, 1.05293901e-01, 0.00000000e+00,
0.00000000e+00, 1.10260318e-01, 1.31643372e-01, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 8.44444677e-01, 6.77765312e-01,
0.00000000e+00, 5.05999801e-02, 1.00552653e-01, 0.00000000e+00,
0.00000000e+00, 1.54582784e-01, 0.00000000e+00, 1.13265831e+00,
0.00000000e+00, 0.00000000e+00, 7.67622627e-01, 0.00000000e+00,
1.23048506e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
3.12075774e-03, 9.20508971e-03, 0.00000000e+00, 0.00000000e+00,
2.24813920e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 3.25643453e-01, 0.00000000e+00, 0.00000000e+00,
5.90566523e-02, 1.91172397e-01, 0.00000000e+00, 3.41801208e-02,
0.00000000e+00, 1.21313949e-01, 7.98333390e-02, 0.00000000e+00])
Optimized value of wa_new (jaxopt)
[0.00000000e+00 5.46834469e-01 0.00000000e+00 2.78253078e-01
7.89659377e-03 1.30613953e-01 0.00000000e+00 6.36773586e-01
8.86588693e-01 1.54841435e+00 0.00000000e+00 8.11460614e-01
0.00000000e+00 2.19483122e-01 3.30503821e-01 1.15662307e-01
8.80117357e-01 1.58519089e-01 0.00000000e+00 7.15284869e-02
7.65197736e-04 1.42049563e+00 1.31482124e-01 1.29014403e-01
0.00000000e+00 0.00000000e+00 0.00000000e+00 1.60738841e-01
0.00000000e+00 2.11134859e-04 7.69150443e-04 0.00000000e+00
4.57316667e-01 5.89280963e-01 1.16028547e+00 1.55092672e-01
8.82897153e-02 0.00000000e+00 8.81871283e-02 0.00000000e+00
1.20857454e-04 3.02751184e-01 7.04137236e-02 0.00000000e+00
4.84485656e-01 3.28849182e-02 8.90319526e-01 5.75939834e-04
4.59955156e-01 0.00000000e+00 7.57855771e-04 0.00000000e+00
0.00000000e+00 3.89564373e-02 5.76256029e-02 2.97483569e-03
5.17785847e-01 1.83005854e-01 2.27913812e-01 1.99340269e-01
0.00000000e+00 4.21245098e-01 4.41471756e-01 3.43075842e-01
0.00000000e+00 4.60929386e-02 9.65917408e-02 7.06027970e-02
3.14517925e-03 1.59189552e-01 1.21413296e-13 5.95388710e-01
2.26187472e-12 8.46469775e-02 8.20553839e-01 8.88178420e-16
1.55108944e-01 2.44511813e-02 0.00000000e+00 0.00000000e+00
6.37098476e-02 4.20914859e-01 1.23776801e-01 1.65834844e-01
4.72670466e-01 9.22673717e-02 8.98242220e-02 0.00000000e+00
0.00000000e+00 1.64450020e-01 1.21536679e-04 5.83359003e-02
8.76099318e-02 1.33519948e-01 0.00000000e+00 9.29643065e-02
9.17113125e-02 3.66380751e-01 1.61518916e-01 7.18804449e-02]