
Batching when using an event function

Open · anandtrex opened this issue 4 years ago · 1 comment

I was wondering how batching would work when using an event function. #122 describes a way of handling batching, but I can't figure out how to do this while assigning a specific event to a specific batch element without re-running event_fun on the output solution. Thanks.

anandtrex · Apr 18 '22 14:04

Currently the only thing we support is having the event function output a non-scalar tensor, e.g. of shape (batch_size,). The solver will then stop when the first event occurs, and checking torch.where(event_fn == 0) (or argmin(abs(event_fn))) tells you which batch element triggered that event. For better asynchronous solving, you might want to use torch.distributed instead.

rtqichen · Apr 18 '22 15:04