PyTorch-Wavelet-Toolbox

Support torch.compile()?

Open lijun2005 opened this issue 1 year ago • 6 comments

Hi, thank you for providing this toolbox. Is any future work planned to support torch.compile() in ptwt?

lijun2005, Aug 08 '24 05:08

We do support torch.jit to some degree. It's similar to torch.compile. See also https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/main/tests/test_jit.py.

v0lta, Aug 09 '24 07:08

Sincere thanks for your prompt response. In my opinion, torch.compile is more convenient and faster than torch.jit. I believe that once ptwt supports torch.compile, it will be even more convenient and user-friendly for the community.

lijun2005, Aug 09 '24 07:08

We haven't tested it. Does a working torch.jit test imply that torch.compile works as well?

v0lta, Aug 22 '24 11:08

torch.compile() reports that some operations in ptwt are not supported.

lijun2005, Aug 22 '24 13:08

Could you provide an example of how to apply a learnable wavelet transform to a 3D tensor? The input has the shape [B, L, C], where B is the batch size, L the length, and C the number of channels. I want to perform the wavelet transform along the L dimension.

lijun2005, Sep 21 '24 05:09

Sure:

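from tqdm import tqdm
import torch
import ptwt
from ptwt.wavelets_learnable import ProductFilter

torch.manual_seed(42)
# 4D input: wavedec3 transforms the last three axes, the first axis acts as batch.
aten = torch.randn(32, 32, 32, 32)
# Learnable filter bank (dec_lo, dec_hi, rec_lo, rec_hi), randomly initialized.
wavelet = ProductFilter(torch.randn(4), torch.randn(4),
                        torch.randn(4), torch.randn(4))
opt = torch.optim.RMSprop(wavelet.parameters(), lr=0.01)

for _i in (bar := tqdm(range(5000))):
    res = ptwt.waverec3(ptwt.wavedec3(aten, wavelet, level=4), wavelet)
    # Reconstruction error plus the filter-bank penalty from wavelet_loss()
    # (perfect reconstruction and alias cancellation conditions).
    cost = torch.mean((res - aten) ** 2) + wavelet.wavelet_loss()
    cost.backward()
    opt.step()
    opt.zero_grad()
    # .item() also works when the tensors live on a GPU.
    bar.set_description(f"cost: {cost.item():2.4e}")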

Pass [1, B, L, C] into wavedec3 if you want to transform the batch dimension; wavedec3 transforms the last three axes. I hope this helps. See https://arxiv.org/pdf/2004.09569 for more information on learnable wavelets.

v0lta, Sep 25 '24 16:09