
Training an Event Function

Open rrm45 opened this issue 3 years ago • 4 comments

Hi,

I've been trying to train a drift function and an event function, but the parameters of my event function are not changing. Here's a simplified portion of the code I've been working with that only deals with the event function:

    loss_fn = nn.MSELoss(reduction='sum')
    func = ODEFunc().to(device).double()
    event = ODEEvent().to(device).double()
    optimizer = optim.Adam(event.parameters(), lr=0.001)

    for itr in range(iters):
        optimizer.zero_grad()
        event_t, pred = odeint_event(func, v0, t0, event_fn=event,
                                     odeint_interface=odeint_adjoint,
                                     method='bosh3', atol=1e-6)
        loss = loss_fn(event_t, true_t)
        loss.backward()
        optimizer.step()

The parameters of event never change between the iterations and list(event.parameters())[0].grad is always None. How do I get the gradients w.r.t. the time points so my event function can learn? Any help is appreciated.

rrm45 avatar Aug 10 '22 05:08 rrm45

Ah, the parameters of the event function currently need to be part of the state if you want gradients to pass into the event function parameters. I've written a colab notebook based off of your code: https://colab.research.google.com/drive/1bMNBGMdSX_NQ-wCJqb0h81WcUknc_GM9?usp=sharing

rtqichen avatar Aug 10 '22 14:08 rtqichen

Thank you for your response and for taking the time to write that code. The event training definitely works now with that correction. If I now want to train a drift function simultaneously, do I just include:

    pred_v = odeint(func, v0, t)
    loss = loss_fn(event_t, true_t) + loss_fn(pred_v, true_v)

in the loop, where t runs from t0 to event_t? Do I need to include the drift function parameters in the state as well in this case?

rrm45 avatar Aug 10 '22 21:08 rrm45

It sounds like you want to define a loss on multiple values of t between 0 and event_t? This can be done in a single odeint call:

    ts = torch.linspace(0, 1, T) * event_t  # construct T time values at which to query z(t)
    preds = odeint(func, v0, ts)

where preds is of shape (T, ...) where preds[i] is the solution at time ts[i].

Although, if the type of loss you want to compute is of the form

L = int_0^{event_t} R(z(t), y(t)) dt

then it might make sense to compute this within the ODE solver (i.e. by augmenting the state with L, using dL/dt = R(z(t), y(t)), and computing R as part of the ODE function).

Do I need to include the drift function parameters in the state as well in this case?

Shouldn't need to do this; the drift function's parameters are handled automatically. The only caveat is that whenever odeint_adjoint is used, the drift function needs to be implemented as an nn.Module.

(Sorry for the many gotchas.)

rtqichen avatar Aug 10 '22 21:08 rtqichen

Thanks for the suggestion, I'll try that out. I had one final question. Is it possible to train a drift and event function simultaneously by computing a loss only on the trajectory between t0 and event_t?

Currently, I am computing two losses: one between the predicted and true event_t, and one between the predicted and true trajectory from t0 to event_t. I was wondering whether it would be possible to train the drift and event functions using only the trajectory loss, or do I have to explicitly define a loss on event_t? (When I remove the event_t loss, all my event gradients are 0.)

Thank you for all the help.

rrm45 avatar Aug 26 '22 19:08 rrm45