Support torch.compile()?
Hi, thank you for providing this toolbox. Is any work planned toward supporting torch.compile() in ptwt?
We do support torch.jit to some degree. It's similar to torch.compile. See also https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/main/tests/test_jit.py for an example.
Sincere thanks for your prompt response. In my opinion, torch.compile is more convenient and faster than torch.jit. I believe that once ptwt supports torch.compile, it will be even more convenient and user-friendly for the community.
We haven't tested it. Does a working torch.jit test imply that torch.compile works as well?
torch.compile() reports that some operations in ptwt are not supported.
Could you provide an example of how to apply learnable wavelet transforms to a 3D tensor? The input has the shape [B, L, C], where B is the batch size, L the length, and C the number of channels. I want to perform the wavelet transform along the L dimension.
Sure:
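```python
import ptwt
import torch
from tqdm import tqdm
from ptwt.wavelets_learnable import ProductFilter

torch.manual_seed(42)
# Random 4D test signal of shape [batch, depth, height, width].
aten = torch.randn(32, 32, 32, 32)
# Learnable filter bank, initialized with four random length-4 filters.
wavelet = ProductFilter(torch.randn(4), torch.randn(4),
                        torch.randn(4), torch.randn(4))
opt = torch.optim.RMSprop(wavelet.parameters(), lr=0.01)
for _i in (bar := tqdm(range(5000))):
    # Multilevel 3D decomposition and reconstruction over the last three axes.
    res = ptwt.waverec3(ptwt.wavedec3(aten, wavelet, level=4), wavelet)
    # Reconstruction error plus the filter bank's wavelet loss term.
    cost = torch.mean((res - aten) ** 2) + wavelet.wavelet_loss()
    cost.backward()
    opt.step()
    opt.zero_grad()
    bar.set_description(f"cost: {cost.item():2.4e}")
```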
wavedec3 transforms the last three axes, so pass your data with an extra leading axis, i.e. shape [1, B, L, C], if you want to transform along the batch dimension as well. I hope this helps. See https://arxiv.org/pdf/2004.09569 for more information regarding learnable wavelets.