Question about Class FuseBlock7 in common.py
I have a question about some implementation details of class FuseBlock7 in common.py:
```python
import torch
import torch.nn as nn

# Attention is defined elsewhere in common.py.

class FuseBlock7(nn.Module):
    def __init__(self, channels):
        super(FuseBlock7, self).__init__()
        self.fre = nn.Conv2d(channels, channels, 3, 1, 1)
        self.spa = nn.Conv2d(channels, channels, 3, 1, 1)
        self.fre_att = Attention(dim=channels)
        self.spa_att = Attention(dim=channels)
        # Produces sigmoid gates for the two branches.
        self.fuse = nn.Sequential(
            nn.Conv2d(2 * channels, channels, 3, 1, 1),
            nn.Conv2d(channels, 2 * channels, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, spa, fre):
        ori = spa
        fre = self.fre(fre)
        spa = self.spa(spa)
        fre = self.fre_att(fre, spa) + fre
        spa = self.fre_att(spa, fre) + spa  # fre_att is used here again; see the question below
        fuse = self.fuse(torch.cat((fre, spa), 1))
        fre_a, spa_a = fuse.chunk(2, dim=1)
        spa = spa_a * spa
        fre = fre * fre_a
        res = fre + spa
        res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5)
        return res
```
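For context, here is a minimal usage sketch of the block above (assuming Attention(dim=C) returns a tensor with the same shape as its first argument; the channel count and input sizes are placeholders):

```python
# Hypothetical example, not taken from the repository.
block = FuseBlock7(channels=64)
spa_feat = torch.randn(1, 64, 32, 32)  # spatial-domain features
fre_feat = torch.randn(1, 64, 32, 32)  # frequency-domain features
out = block(spa_feat, fre_feat)        # same shape as the inputs: (1, 64, 32, 32)
```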
The original code is: spa = self.fre_att(spa, fre) + spa
Should it be spa = self.spa_att(spa, fre) + spa ?
Hello, the original code is correct.
So fre and spa go through the same Attention module, and self.spa_att = Attention(dim=channels) is never used. I have some questions about this. Could you explain the reason? Thank you.
Why did my gradients explode after using your FreBlock9, with the loss going to NaN? Can you help me? Thanks!
@LanCole Thanks for pointing this out! You’re right — this is a typo in the code. It doesn’t affect the model’s performance, but we’ll fix it in the next update for clarity.
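For anyone reading this later, the intended lines would presumably look like this (a sketch of the symmetric form, not yet a committed fix):

```python
# Intended symmetric form; the repository code still has the typo.
fre = self.fre_att(fre, spa) + fre
spa = self.spa_att(spa, fre) + spa  # spa_att instead of fre_att
```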
@D6582 Thanks for your question! We also encountered this issue. The gradient explosion comes from the Fourier transform part of FreBlock9. We are looking into better fixes, but for now you can try normalizing the FFT/IFFT and adding a small epsilon to avoid division by very small values; this helps stabilize training.
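To make that suggestion concrete, here is a rough sketch of a frequency-domain branch with those two stabilizations applied; the function name and structure are placeholders, not FreBlock9's actual code:

```python
import torch

def stabilized_freq_branch(x, eps=1e-8):
    # Orthonormal FFT keeps magnitudes in a reasonable range.
    freq = torch.fft.rfft2(x, norm='ortho')
    # Add a small epsilon so later divisions/logs on the magnitude don't blow up.
    mag = torch.abs(freq) + eps
    pha = torch.angle(freq)
    # ... apply the block's convolutions to mag / pha here ...
    real = mag * torch.cos(pha)
    imag = mag * torch.sin(pha)
    out = torch.fft.irfft2(torch.complex(real, imag), s=x.shape[-2:], norm='ortho')
    return out
```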