ReconFormer
Fix data-consistency module
Summary
Fixes #10
This PR fixes the data-consistency module and the 2D Fourier transform functions fft2 and ifft2. The data-consistency module has been updated, and the fft2c and ifft2c functions have been added to transforms.py.
Problem definition
The data-consistency module linked below fails when it is called:
https://github.com/guopengf/ReconFormer/blob/e2e0d5e6e58e04ad1c77a1151e63cf56bec21fb1/models/Recurrent_Transformer.py#L13-L55
The failure is caused by errors in the fft2 and ifft2 functions in transforms.py:
https://github.com/guopengf/ReconFormer/blob/e2e0d5e6e58e04ad1c77a1151e63cf56bec21fb1/data/transforms.py#L73-L107
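For context, a data-consistency step usually keeps the measured k-space samples wherever the sampling mask is set and uses the network's predicted samples elsewhere. The following is only a minimal sketch of that idea under stated assumptions, not the repository's implementation: the function name `data_consistency` is hypothetical, the tensors use a `(B, 2, H, W)` real/imaginary layout, the mask is binary, the measurements are treated as noise-free, and FFT centering is omitted for brevity.

```python
import torch


def data_consistency(x, k0, mask):
    """Hypothetical noise-free data-consistency step (illustration only).

    x:    predicted image, shape (B, 2, H, W), channels = (real, imag)
    k0:   measured k-space, same shape and layout as x
    mask: sampling mask, shape (B, 1, H, W), 1 where k-space was measured
    """
    # (B, 2, H, W) real tensors -> complex tensors of shape (B, H, W)
    x_c = torch.complex(x[:, 0], x[:, 1])
    k0_c = torch.complex(k0[:, 0], k0[:, 1])
    m = mask[:, 0]

    # Orthonormal 2D FFT of the prediction over the last two dimensions.
    k = torch.fft.fft2(x_c, norm="ortho")

    # Measured samples where mask == 1, predicted samples elsewhere.
    k_dc = (1 - m) * k + m * k0_c

    # Back to image space and to the (B, 2, H, W) real layout.
    out = torch.fft.ifft2(k_dc, norm="ortho")
    return torch.stack((out.real, out.imag), dim=1)
```

With an all-zero mask the input image passes through unchanged, and with an all-one mask the output's k-space equals `k0`, which makes for a quick sanity check of any implementation.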
To Reproduce
import torch
from backbones.reconformer.reconformer import DataConsistencyInKspace
resolution = 320
device = 'cuda:0'
x = torch.randn((1, 2, resolution, resolution)).to(device)
k0 = torch.randn((1, 2, resolution, resolution)).to(device)
mask = torch.randn((1, 1, resolution, resolution)).to(device)
dc = DataConsistencyInKspace()
out = dc(x, k0, mask)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
46 k0 = k0.permute(0, 2, 3, 1)
47 mask = mask.permute(0, 2, 3, 1)
...
--> 122 data = torch.fft.fft(data, 2, normalized=normalized)
123 data = fftshift(data, dim=(-3, -2))
124 return data
TypeError: fft_fft() got an unexpected keyword argument 'normalized'
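The `normalized` keyword belongs to the legacy `torch.fft(input, signal_ndim, normalized)` function, which was removed in PyTorch 1.8. In current releases `torch.fft` is a module, and its functions take a `norm` argument and operate on complex tensors. As a rough sketch of what centered transforms can look like with the current API (the names `fft2c` and `ifft2c` follow this PR, but the bodies below are an assumption and not necessarily the code it adds), assuming real-valued input of shape `(..., H, W, 2)`:

```python
import torch


def fft2c(data):
    """Centered, orthonormal 2D FFT (illustrative sketch).

    data: real tensor of shape (..., H, W, 2), last dim = (real, imag).
    """
    x = torch.view_as_complex(data.contiguous())
    x = torch.fft.ifftshift(x, dim=(-2, -1))
    x = torch.fft.fft2(x, dim=(-2, -1), norm="ortho")
    x = torch.fft.fftshift(x, dim=(-2, -1))
    return torch.view_as_real(x)


def ifft2c(data):
    """Centered, orthonormal 2D inverse FFT (illustrative sketch)."""
    x = torch.view_as_complex(data.contiguous())
    x = torch.fft.ifftshift(x, dim=(-2, -1))
    x = torch.fft.ifft2(x, dim=(-2, -1), norm="ortho")
    x = torch.fft.fftshift(x, dim=(-2, -1))
    return torch.view_as_real(x)


# Round trip on a tensor permuted from (B, 2, H, W) to (B, H, W, 2),
# mirroring the permute calls shown in the traceback.
image = torch.randn(1, 2, 8, 8).permute(0, 2, 3, 1)
restored = ifft2c(fft2c(image))
```

The `.contiguous()` call matters here because `torch.view_as_complex` requires the last dimension to have stride 1, which a permuted tensor does not guarantee.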