torchdiffeq
Batching when using an event function
I was wondering how batching would work when using an event function.
#122 describes a way of handling batching, but I can't figure out how to do this while still being able to attribute a specific event to a specific batch element without re-running event_fn on the output solution. Thanks.
Currently the only thing we support is having the event function output a non-scalar tensor, e.g. of shape (batch_size,). The solver will then stop when the first event occurs, and checking torch.where(event_fn == 0) (or argmin(abs(event_fn))) tells you which batch element triggered that event. For better asynchronous solving, you might want to use torch.distributed instead.
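The identification step can be sketched as follows. The tensor values here are hypothetical; in a real run they would come from evaluating your event function at the event time the solver returns:

```python
import torch

# Hypothetical event-function values at the moment the solver stopped.
# In practice these come from evaluating event_fn at the event time,
# where event_fn returns a tensor of shape (batch_size,).
event_vals = torch.tensor([0.37, -1e-9, 0.85, 0.12])

# The solver halts at the first root across the batch, so exactly one
# entry is (numerically) zero. Exact comparison with zero is fragile
# for floats, so taking the argmin of the absolute values is safer:
triggered = torch.argmin(torch.abs(event_vals)).item()
print(triggered)  # index of the batch element whose event fired
```

Note that this only inspects the already-computed event-function values, so it avoids re-running event_fn over the whole output solution.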