Using `torchsde.BrownianInterval` instead of `torchsde.BrownianTree` in class `BatchedBrownianTree`
Is your feature request related to a problem? Please describe. While optimizing my pipeline, I found that `BrownianTree` added a bit of extra time.
Describe the solution you'd like.
I dug further into the torchsde documentation and found that it encourages using `BrownianInterval` to get the most benefit from the underlying data structure. `BrownianTree` is actually just an abstraction layer over `BrownianInterval`, and, as we all know, Python function calls take time!
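To make the function-call argument concrete, here is a toy sketch in plain Python. `Interval` and `Tree` are hypothetical classes, not torchsde's: the wrapper only forwards each call, so the timing difference isolates the cost of one extra Python call per query. It illustrates the overhead in principle; whether that overhead matters for `BatchedBrownianTree` depends on how it compares to the actual sampling work, which still has to be measured.

```python
import timeit


class Interval:
    """Hypothetical stand-in for the underlying structure (think BrownianInterval)."""

    def __call__(self, ta, tb):
        return tb - ta  # placeholder work; a real interval would sample an increment


class Tree:
    """Hypothetical thin wrapper (think BrownianTree) that only forwards each call."""

    def __init__(self):
        self._interval = Interval()

    def __call__(self, ta, tb):
        return self._interval(ta, tb)  # one extra Python call per query


interval = Interval()
tree = Tree()

n = 200_000
direct_s = timeit.timeit(lambda: interval(0.1, 0.2), number=n)
wrapped_s = timeit.timeit(lambda: tree(0.1, 0.2), number=n)
print(f"direct:  {direct_s / n * 1e9:.0f} ns/call")
print(f"wrapped: {wrapped_s / n * 1e9:.0f} ns/call")
```

Both objects return the same value; only the call path differs.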
Code:
```python
# diffusers/src/diffusers/schedulers/scheduling_dpmsolver_sde.py:41
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
# Modified: build BrownianInterval directly, taking size/dtype/device from w0;
# cache_size=None gives BrownianInterval an unbounded cache
self.trees = [torchsde.BrownianInterval(t0, t1, size=w0.shape, dtype=w0.dtype, device=w0.device, cache_size=None, entropy=s, **kwargs) for s in seed]
```
Additional context. torchsde doc link
Cc: @yiyixuxu
@sayakpaul @yiyixuxu Any thoughts on this? I think this is low-hanging fruit.
A friendly ping here @yiyixuxu
Hi @dianyo, yes, we would welcome a PR! (Sorry for the very delayed response.)
Hi @yiyixuxu, I've created a PR and pinged you on it!