torchdiffeq
How to pass extra parameters of func to odeint?
I looked at the definition of odeint in torchdiffeq, but did not find a parameter for passing extra parameters, like the args parameter of scipy.integrate.odeint. Is there any other way to pass parameters to odeint besides defining a global variable?
Yeah, just define them anywhere. If you want to use odeint_adjoint, though, it's good practice to define them as part of the module:
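import torch
import torch.nn as nn

global_params = ...  # extra (external) parameters defined outside the module

class ODEfunc(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering parameters on an nn.Module
        # Registered module parameter: odeint_adjoint computes gradients for it.
        # Don't call it `self.parameters`, which would shadow nn.Module.parameters().
        self.p = nn.Parameter(some_tensor_we_want_to_optimize_ie_compute_gradients_for)

    def forward(self, t, x):
        p = self.p  # a Parameter is a tensor, not a method, so no call
        external_p = global_params
        # some ops regarding t, x, p, external_p
        return ...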
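If you use odeint, gradients will be computed w.r.t. external_p as well (as long as it requires grad), since autograd backpropagates through the solver's operations; odeint_adjoint, however, will only compute them for p, i.e. the parameters registered on the module.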
Thanks for your answer!