Training an Event Function
Hi,
I've been trying to train a drift and an event function, but the parameters of my event function are not changing. Here's a simplified portion of the code I've been working with that deals only with the event function:
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint_adjoint, odeint_event

loss_fn = nn.MSELoss(reduction='sum')
func = ODEFunc().to(device).double()
event = ODEEvent().to(device).double()
optimizer = optim.Adam(event.parameters(), lr=0.001)

for itr in range(iters):
    optimizer.zero_grad()
    # find the event time and the solution up to it
    event_t, pred = odeint_event(func, v0, t0, event_fn=event,
                                 odeint_interface=odeint_adjoint,
                                 method='bosh3', atol=1e-6)
    loss = loss_fn(event_t, true_t)
    loss.backward()
    optimizer.step()
The parameters of event never change between iterations, and list(event.parameters())[0].grad is always None. How do I get gradients to flow through the event time into my event function's parameters so it can learn? Any help is appreciated.
Ah, the parameters of the event function currently need to be part of the state if you want gradients to pass into them. I've written a colab notebook based off of your code: https://colab.research.google.com/drive/1bMNBGMdSX_NQ-wCJqb0h81WcUknc_GM9?usp=sharing
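Roughly, the trick looks like this (a toy sketch, not the notebook code; the ODEFunc/ODEEvent internals below are placeholders made up for illustration): append the event function's learnable parameter to the state, give it a zero time derivative in the drift, and have the event function read it back out of the state.

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint, odeint_event

class ODEFunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))

    def forward(self, t, state):
        v, thresh = state[..., :1], state[..., 1:]
        dv = self.net(v).exp()                 # positive drift so the toy event always fires
        dthresh = torch.zeros_like(thresh)     # the event parameter is constant in time
        return torch.cat([dv, dthresh], dim=-1)

class ODEEvent(nn.Module):
    def __init__(self):
        super().__init__()
        self.thresh = nn.Parameter(torch.tensor([1.0]))

    def forward(self, t, state):
        v, thresh = state[..., :1], state[..., 1:]
        return (v - thresh).sum()              # event fires when v reaches the learned threshold

func = ODEFunc().double()
event = ODEEvent().double()
v0 = torch.zeros(1, dtype=torch.float64)
t0 = torch.tensor(0.0, dtype=torch.float64)

# Rebuild the augmented state from the current parameter value every iteration.
state0 = torch.cat([v0, event.thresh])
event_t, pred = odeint_event(func, state0, t0, event_fn=event,
                             odeint_interface=odeint_adjoint,
                             method='bosh3', atol=1e-6)
event_t.backward()                             # event.thresh.grad is now non-None

In your training loop you would compute a loss on event_t instead of calling backward on it directly, but the augmentation of the state is the important part.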
Thank you for your response and for taking the time to write that code. The event training definitely works now with that correction. If I now want to train a drift function simultaneously, do I just include:
pred_v = odeint(func, v0, t)
loss = loss_fn(event_t, true_t) + loss_fn(pred_v, true_v)
in the loop, where t runs from t0 to event_t? Do I need to include the drift function's parameters in the state as well in this case?
It sounds like you want to define a loss on multiple values of t between 0 and event_t? This can be done in a single odeint call:
ts = torch.linspace(0, 1, T) * event_t # construct T time values to query z(t)
preds = odeint(func, v0, ts)
Here preds has shape (T, ...), and preds[i] is the solution at time ts[i].
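If it helps, one combined training step could look roughly like this (just a sketch: loss_fn, T, true_t, true_v, iters and the parameter-augmented state0 are assumed from the snippets above):

import torch
import torch.optim as optim
from torchdiffeq import odeint, odeint_adjoint, odeint_event

optimizer = optim.Adam(list(func.parameters()) + list(event.parameters()), lr=0.001)

for itr in range(iters):
    optimizer.zero_grad()
    state0 = torch.cat([v0, event.thresh])     # augmented state, rebuilt each iteration
    event_t, _ = odeint_event(func, state0, t0, event_fn=event,
                              odeint_interface=odeint_adjoint,
                              method='bosh3', atol=1e-6)
    ts = torch.linspace(0, 1, T, dtype=state0.dtype) * event_t   # T times in [0, event_t]
    preds = odeint(func, state0, ts)
    pred_v = preds[..., :1]                    # drop the parameter channel
    loss = loss_fn(event_t, true_t) + loss_fn(pred_v, true_v)
    loss.backward()
    optimizer.step()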
Although, if the type of loss you actually want to compute is an integral of the form
L = int_0^{event_t} R(z(t), y(t)) dt
then it might make sense to compute it within the ODE solver itself, i.e. by augmenting the state with dL/dt = R(z(t), y(t)) and computing R as part of the ODE function.
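For concreteness, that could look something like the following (a sketch; the AugmentedDrift name and the squared-error choice of R are placeholders, and y(t) is assumed to be available as a function of t):

import torch
import torch.nn as nn

class AugmentedDrift(nn.Module):
    def __init__(self, drift, target_fn):
        super().__init__()
        self.drift = drift            # original dz/dt
        self.target_fn = target_fn    # y(t), assumed known

    def forward(self, t, state):
        z = state[..., :-1]           # last channel is the running loss, ignored by the drift
        dz = self.drift(t, z)
        # dL/dt = R(z(t), y(t)); squared error is just one example of R
        dl = ((z - self.target_fn(t)) ** 2).sum(dim=-1, keepdim=True)
        return torch.cat([dz, dl], dim=-1)

Integrating the augmented state (with the loss channel initialized to zero) up to event_t then gives the integral loss as the final value of the last channel, instead of approximating it on a grid of T points.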
Do I need to include the drift function parameters in the state as well in this case?
You shouldn't need to do this; the drift function's parameters are handled automatically. The only catch is that whenever odeint_adjoint is used, the drift function needs to be implemented as an nn.Module.
(Sorry for the many gotchas.)
Thanks for the suggestion, I'll try that out. I have one final question: is it possible to train the drift and event functions simultaneously using a loss computed only on the trajectory between t0 and event_t?
Currently I compute two losses: one between the predicted and true event_t, and one between the predicted and true trajectory on [t0, event_t]. I was wondering whether the trajectory loss alone would be enough, or whether I have to explicitly define a loss on event_t (when I remove the event_t loss, all of my event function's gradients are 0).
Thank you for all the help.