
jaxopt's L-BFGS-B with custom gradient not matching the scipy implementation

Open · jithendaraa opened this issue 1 year ago · 1 comment

Context:

I am migrating code (causalnex's dynotears) from a numpy/scipy implementation to a JAX implementation. This essentially means moving from scipy's L-BFGS-B to jaxopt's implementation so that I can JIT-compile the function and run it faster.

Besides the objective _func(...) to minimize, the code defines a custom gradient function _grad(...) that is passed to the optimizer. I converted both _func() and _grad() to their JAX counterparts, and I use jaxopt.LBFGSB with the custom gradient as follows.

Original numpy/scipy implementation

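```python
import warnings

import numpy as np
import scipy.optimize as sopt

# _func, _grad, _h, bnds, p_orders, d_vars, max_iter and h_tol are defined elsewhere;
# _func and _grad are assumed to read the current rho and alpha from this scope.

# initialise matrix, weights and constraints
wa_est = np.zeros(2 * (p_orders + 1) * d_vars**2)
wa_new = np.zeros(2 * (p_orders + 1) * d_vars**2)
rho, alpha, h_value, h_new = 1.0, 0.0, np.inf, np.inf

for n_iter in range(max_iter):
    # Increase the penalty rho until the acyclicity measure h drops enough.
    while (rho < 1e20) and (h_new > 0.25 * h_value or h_new == np.inf):
        wa_new = sopt.minimize(
            _func,
            wa_est,
            method="L-BFGS-B",
            jac=_grad,
            bounds=bnds,
        ).x
        h_new = _h(wa_new, d_vars, p_orders)
        if h_new > 0.25 * h_value:
            rho *= 10

    # Accept the new estimate and update the Lagrange multiplier.
    wa_est = wa_new
    h_value = h_new
    alpha += rho * h_value
    if h_value <= h_tol:
        break
    if h_value > h_tol and n_iter == max_iter - 1:
        warnings.warn("Failed to converge. Consider increasing max_iter.")
```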

My current jaxopt implementation

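```python
import warnings

import jax.numpy as jnp
import numpy as np
from jaxopt import LBFGSB

# bnds is a list of (lower, upper) tuples in which a bound may be None (unbounded).
# Convert it to the (lower_array, upper_array) pair that jaxopt.LBFGSB expects,
# mapping a missing lower bound to -inf and a missing upper bound to +inf.
np_lowers = np.array([-np.inf if lo is None else lo for lo, _ in bnds], dtype=float)
np_uppers = np.array([np.inf if hi is None else hi for _, hi in bnds], dtype=float)
jnp_lbfgs_bounds = (jnp.asarray(np_lowers), jnp.asarray(np_uppers))

# _func_jax returns (loss, gradient), hence value_and_grad=True.
lbfgsb_solver = LBFGSB(fun=_func_jax, value_and_grad=True)

# wa_est, rho, alpha, h_value and h_new are initialised as in the scipy version;
# _func_jax is assumed to read rho and alpha from the enclosing scope, like _func.
for n_iter in range(max_iter):
    while (rho < 1e20) and (h_new > 0.25 * h_value or h_new == jnp.inf):
        wa_new = lbfgsb_solver.run(
            wa_est,
            bounds=jnp_lbfgs_bounds,
        ).params

        h_new = _h_jax(wa_new, d_vars, p_orders)
        if h_new > 0.25 * h_value:
            rho *= 10

    wa_est = wa_new
    h_value = h_new
    alpha += rho * h_value
    if h_value <= h_tol:
        break
    if h_value > h_tol and n_iter == max_iter - 1:
        warnings.warn("Failed to converge. Consider increasing max_iter.")
```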

I have made sure that _func_jax returns (loss, _grad_jax(params)), whereas _func() returns only the scalar loss. I don't expect identical results from the scipy and jaxopt implementations, since numerical differences are unavoidable even with fixed seeds, but the mismatch between the two versions is large.
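Because _func_jax pairs the loss with a hand-written gradient, any disagreement between the two would send L-BFGS-B in the wrong directions and could also make the line search fail. A quick sanity check is to compare the custom gradient with jax.grad of the loss at a random point. The least-squares loss, lam, and both gradient functions below are hypothetical stand-ins for _func_jax and _grad_jax; grad_buggy deliberately drops the penalty term to show what a mismatch looks like.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def max_grad_error(fun, grad_fn, w):
    """Largest absolute difference between a hand-written gradient and jax.grad."""
    return float(jnp.max(jnp.abs(grad_fn(w) - jax.grad(fun)(w))))


# Hypothetical data: least squares plus a linear penalty on nonnegative weights.
key_x, key_y, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(key_x, (20, 5))
y = jax.random.normal(key_y, (20,))
lam = 0.1


def loss(w):
    r = X @ w - y
    return 0.5 * jnp.sum(r**2) + lam * jnp.sum(w)


def grad_correct(w):
    return X.T @ (X @ w - y) + lam


def grad_buggy(w):
    return X.T @ (X @ w - y)  # forgets the lam * sum(w) term


def func_and_grad(w):
    # The (value, gradient) pair expected by LBFGSB(..., value_and_grad=True).
    return loss(w), grad_correct(w)


w0 = jax.random.uniform(key_w, (5,))
err_ok = max_grad_error(loss, grad_correct, w0)
err_bug = max_grad_error(loss, grad_buggy, w0)
print(err_ok, err_bug)  # err_ok is at round-off level, err_bug is close to lam
```

If the check passes, the hand-written gradient is not the culprit; jax.value_and_grad(loss) is also a drop-in way to build the (value, gradient) pair without writing the gradient by hand.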

I also get warnings such as these during the run:

WARNING: jaxopt.ZoomLineSearch: No interval satisfying curvature condition.Consider increasing maximal possible stepsize of the linesearch.
WARNING: jaxopt.ZoomLineSearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.

It would really help to understand what causes these differences, and whether they are expected.

Versions:

jax: 0.4.31
jaxopt: 0.8.3
numpy: 1.23.5
scipy: 1.13.1

jithendaraa · Jan 26 '25 05:01

Here are the values from my scipy and jaxopt runs, in that order:

Final objective: 4.84871 (scipy) vs 1.5213378e+20 (jaxopt), which seems concerning.
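Two hedged observations about that number. It is printed with eight significant digits, which is typical of float32, JAX's default dtype unless 64-bit mode is enabled, while the scipy version runs in float64. Its magnitude is also of the same order as the rho < 1e20 cap in the loop, which may indicate that rho grew to that cap. If _h has the NOTEARS-style form trace(expm(W * W)) - d_vars (an assumption, since _h is not shown), h is a small difference of two numbers close to d_vars, and float32 cannot resolve small h. The sketch assumes d_vars = 5, which is consistent with the 100-element wa vector if p_orders = 1, and a hypothetical true h of 1e-8.

```python
import numpy as np

# Assumed form: h = trace(expm(W * W)) - d_vars, i.e. (value near d_vars) - d_vars.
d_vars = 5.0
true_h = 1e-8  # hypothetical small constraint violation

h32 = (np.float32(d_vars) + np.float32(true_h)) - np.float32(d_vars)
h64 = (np.float64(d_vars) + np.float64(true_h)) - np.float64(d_vars)

print(h32)  # 0.0: in float32, 5 + 1e-8 rounds back to 5
print(h64)  # close to 1e-8
# Gap between adjacent representable numbers near d_vars in each precision.
print(np.spacing(np.float32(d_vars)), np.spacing(np.float64(d_vars)))
```

If precision is the issue, calling jax.config.update("jax_enable_x64", True) before any JAX arrays are created may bring the jaxopt run closer to the scipy one.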

Optimized value of wa_new (scipy/numpy)

array([0.00000000e+00, 1.58362135e+00, 0.00000000e+00, 8.64218226e-01,
       0.00000000e+00, 5.30177996e-06, 0.00000000e+00, 7.23155393e-06,
       1.08522093e-04, 3.21722607e-01, 0.00000000e+00, 1.58837897e+00,
       0.00000000e+00, 2.17025397e-01, 5.02410566e-01, 9.24620810e-06,
       2.14839298e-01, 4.42866934e-05, 0.00000000e+00, 0.00000000e+00,
       4.66671845e-06, 9.10319126e-05, 2.85940340e-05, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 8.72885996e-04,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.54268367e-02,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       3.84189744e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       4.18412632e-05, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 1.66891599e-01, 1.05293901e-01, 0.00000000e+00,
       0.00000000e+00, 1.10260318e-01, 1.31643372e-01, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 8.44444677e-01, 6.77765312e-01,
       0.00000000e+00, 5.05999801e-02, 1.00552653e-01, 0.00000000e+00,
       0.00000000e+00, 1.54582784e-01, 0.00000000e+00, 1.13265831e+00,
       0.00000000e+00, 0.00000000e+00, 7.67622627e-01, 0.00000000e+00,
       1.23048506e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       3.12075774e-03, 9.20508971e-03, 0.00000000e+00, 0.00000000e+00,
       2.24813920e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 3.25643453e-01, 0.00000000e+00, 0.00000000e+00,
       5.90566523e-02, 1.91172397e-01, 0.00000000e+00, 3.41801208e-02,
       0.00000000e+00, 1.21313949e-01, 7.98333390e-02, 0.00000000e+00])

Optimized value of wa_new (jaxopt)

[0.00000000e+00 5.46834469e-01 0.00000000e+00 2.78253078e-01
 7.89659377e-03 1.30613953e-01 0.00000000e+00 6.36773586e-01
 8.86588693e-01 1.54841435e+00 0.00000000e+00 8.11460614e-01
 0.00000000e+00 2.19483122e-01 3.30503821e-01 1.15662307e-01
 8.80117357e-01 1.58519089e-01 0.00000000e+00 7.15284869e-02
 7.65197736e-04 1.42049563e+00 1.31482124e-01 1.29014403e-01
 0.00000000e+00 0.00000000e+00 0.00000000e+00 1.60738841e-01
 0.00000000e+00 2.11134859e-04 7.69150443e-04 0.00000000e+00
 4.57316667e-01 5.89280963e-01 1.16028547e+00 1.55092672e-01
 8.82897153e-02 0.00000000e+00 8.81871283e-02 0.00000000e+00
 1.20857454e-04 3.02751184e-01 7.04137236e-02 0.00000000e+00
 4.84485656e-01 3.28849182e-02 8.90319526e-01 5.75939834e-04
 4.59955156e-01 0.00000000e+00 7.57855771e-04 0.00000000e+00
 0.00000000e+00 3.89564373e-02 5.76256029e-02 2.97483569e-03
 5.17785847e-01 1.83005854e-01 2.27913812e-01 1.99340269e-01
 0.00000000e+00 4.21245098e-01 4.41471756e-01 3.43075842e-01
 0.00000000e+00 4.60929386e-02 9.65917408e-02 7.06027970e-02
 3.14517925e-03 1.59189552e-01 1.21413296e-13 5.95388710e-01
 2.26187472e-12 8.46469775e-02 8.20553839e-01 8.88178420e-16
 1.55108944e-01 2.44511813e-02 0.00000000e+00 0.00000000e+00
 6.37098476e-02 4.20914859e-01 1.23776801e-01 1.65834844e-01
 4.72670466e-01 9.22673717e-02 8.98242220e-02 0.00000000e+00
 0.00000000e+00 1.64450020e-01 1.21536679e-04 5.83359003e-02
 8.76099318e-02 1.33519948e-01 0.00000000e+00 9.29643065e-02
 9.17113125e-02 3.66380751e-01 1.61518916e-01 7.18804449e-02]
