torchdiffeq

How to pass extra parameters of func to odeint?

Open · shifttttttt opened this issue 2 years ago · 3 comments

I looked at the definition of odeint in torchdiffeq, but could not find a parameter for passing extra parameters, like the args parameter in scipy.integrate.odeint. Is there any way to pass parameters to odeint other than defining a global variable?

shifttttttt · Jan 23 '24 12:01

Yeah, you can just define them anywhere. If you plan to use odeint_adjoint, though, it's good practice to define them as part of the module.

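```python
import torch.nn as nn

global_params = ...

class ODEfunc(nn.Module):

  def __init__(self):
    super().__init__()  # required before registering parameters
    # Registered parameter, so odeint_adjoint computes gradients for it.
    # (Naming it `parameters` would clash with nn.Module.parameters().)
    self.p = nn.Parameter(some_tensor_we_want_to_optimize_ie_compute_gradients_for)

  def forward(self, t, x):
    p = self.p
    external_p = global_params
    # some ops regarding t, x, p, external_p
    return ...
```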

If you use odeint, gradients will also be computed w.r.t. external_p (provided it requires grad), but odeint_adjoint will, by default, only compute them for p.

rtqichen · Mar 12 '24 02:03

Thanks for your answer!

shifttttttt · Mar 16 '24 06:03