pytorch_fft
slow speed on torch.DataParallel models
Hi,
I am trying to use this framework to compute FFTs inside a torch.DataParallel(model), but with the same per-GPU batch size, the FFT takes much more time on 4 GPUs than on a single GPU:
with batch_size 16 on one GPU (input of size 784 x 8192, 1-D FFT):
the fft takes about 0.60s and the ifft about 0.21s.
but with batch_size 64 on 4 GPUs (i.e. the same 16 samples per GPU):
the fft takes about 4s and the ifft about 1s.
So could you provide enhancements for multi-GPU FFT? Thanks.
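Concretely, the transforms being timed are 1-D FFTs over the last dimension of a 784 x 8192 CUDA tensor, roughly along these lines (just a sketch with illustrative tensor names, using the library's function interface on raw CUDA tensors):

```python
import torch
import pytorch_fft.fft as cfft  # pytorch_fft's function interface

# 784 rows, each a length-8192 signal; the imaginary part starts at zero
x_re = torch.randn(784, 8192).cuda()
x_im = torch.zeros(784, 8192).cuda()

y_re, y_im = cfft.fft(x_re, x_im)    # 1-D FFT along the last dimension
z_re, z_im = cfft.ifft(y_re, y_im)   # inverse FFT, recovers x up to numerical error
```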
For example, with the following FFT module:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable
import pytorch_fft.fft.autograd as afft  # the library's autograd interface

class testFFT(nn.Module):
    def __init__(self):
        super(testFFT, self).__init__()

    def forward(self, x):
        # FFT of x with an all-zero imaginary part
        new_out = afft.Fft(x, Variable(torch.zeros(x.size())).cuda())
        return x

x = Variable(torch.Tensor(64 * 14 * 14, 8192)).cuda()
temp_fft = torch.nn.DataParallel(testFFT())
temp_fft.cuda()
out = temp_fft(x)
```

With one GPU (a K80), this takes only about 0.35s; with 2 GPUs, it takes about 6s!
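To reproduce the comparison, a simple loop like the sketch below can be used (this is an illustrative harness, not the exact script behind the numbers above; `time_forward` is a hypothetical helper and `testFFT` is the module from the snippet above). Synchronizing before and after the loop makes sure the asynchronous CUDA kernels are actually included in the measurement:

```python
import time

import torch
from torch.autograd import Variable


def time_forward(module, x, n_iters=10):
    """Average forward-pass time of `module` on `x`, in seconds."""
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(n_iters):
        module(x)
    torch.cuda.synchronize()
    return (time.time() - start) / n_iters


x = Variable(torch.randn(64 * 14 * 14, 8192)).cuda()

single_gpu = testFFT().cuda()                        # testFFT from the snippet above
multi_gpu = torch.nn.DataParallel(testFFT()).cuda()  # replicated across all visible GPUs

print('single GPU:   %.2fs' % time_forward(single_gpu, x))
print('DataParallel: %.2fs' % time_forward(multi_gpu, x))
```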