加上 nn.Unfold 后，模型的计算结果不正确
error log | 日志或报错信息 | ログ
model | 模型 | モデル
- original model

import torch
import torch.nn as nn
import numpy as np
import ncnn
class ERB(nn.Module):
    """Fixed (non-trainable) ERB band-merging / band-splitting layers.

    The lowest ``erb_subband_1`` frequency bins are passed through
    unchanged; the remaining ``nfft // 2 + 1 - erb_subband_1`` bins are
    projected onto ``erb_subband_2`` bands by a bias-free linear layer whose
    weight is the filter bank built by :meth:`erb_filter_banks`.  The inverse
    projection uses the transposed filter bank.

    Args:
        erb_subband_1: Number of low-frequency bins kept as-is.
        erb_subband_2: Number of ERB bands for the remaining bins.
        nfft: FFT size; the spectrum has ``nfft // 2 + 1`` bins.
        high_lim: Upper frequency limit of the filter bank, in Hz.
        fs: Sampling rate, in Hz.
    """

    def __init__(self, erb_subband_1, erb_subband_2, nfft=512, high_lim=8000, fs=16000):
        super().__init__()
        erb_filters = self.erb_filter_banks(erb_subband_1, erb_subband_2, nfft, high_lim, fs)
        nfreqs = nfft // 2 + 1
        self.erb_subband_1 = erb_subband_1
        self.erb_fc = nn.Linear(nfreqs - erb_subband_1, erb_subband_2, bias=False)
        self.ierb_fc = nn.Linear(erb_subband_2, nfreqs - erb_subband_1, bias=False)
        # Replace the randomly initialised weights with the frozen filter bank.
        self.erb_fc.weight = nn.Parameter(erb_filters, requires_grad=False)
        self.ierb_fc.weight = nn.Parameter(erb_filters.T, requires_grad=False)

    def hz2erb(self, freq_hz):
        """Map a frequency in Hz to the ERB scale: 21.4 * log10(0.00437 * f + 1)."""
        erb_f = 21.4 * np.log10(0.00437 * freq_hz + 1)
        return erb_f

    def erb2hz(self, erb_f):
        """Inverse of :meth:`hz2erb`: map an ERB-scale value back to Hz."""
        freq_hz = (10 ** (erb_f / 21.4) - 1) / 0.00437
        return freq_hz

    def erb_filter_banks(self, erb_subband_1, erb_subband_2, nfft=512, high_lim=8000, fs=16000):
        """Build the filter bank used as the weight of ``erb_fc``.

        Band edges are spaced uniformly on the ERB scale between the
        frequency of bin ``erb_subband_1`` and ``high_lim`` and then rounded
        to FFT bin indices.

        Returns:
            A float32 tensor of shape
            ``(erb_subband_2, nfft // 2 + 1 - erb_subband_1)``.
        """
        low_lim = erb_subband_1 / nfft * fs
        erb_low = self.hz2erb(low_lim)
        erb_high = self.hz2erb(high_lim)
        erb_points = np.linspace(erb_low, erb_high, erb_subband_2)
        bins = np.round(self.erb2hz(erb_points) / fs * nfft).astype(np.int32)
        erb_filters = np.zeros([erb_subband_2, nfft // 2 + 1], dtype=np.float32)

        # First band: falling ramp only.  The 1e-12 terms guard against a
        # zero-width band (0 / 0).
        erb_filters[0, bins[0]:bins[1]] = (bins[1] - np.arange(bins[0], bins[1]) + 1e-12) \
            / (bins[1] - bins[0] + 1e-12)
        # Middle bands: rising ramp over [bins[i], bins[i+1]) followed by a
        # falling ramp over [bins[i+1], bins[i+2]).
        for i in range(erb_subband_2 - 2):
            erb_filters[i + 1, bins[i]:bins[i + 1]] = (np.arange(bins[i], bins[i + 1]) - bins[i] + 1e-12) \
                / (bins[i + 1] - bins[i] + 1e-12)
            erb_filters[i + 1, bins[i + 1]:bins[i + 2]] = (bins[i + 2] - np.arange(bins[i + 1], bins[i + 2]) + 1e-12) \
                / (bins[i + 2] - bins[i + 1] + 1e-12)
        # Last band: complement of the previous band over the final interval
        # (inclusive of the last bin).
        erb_filters[-1, bins[-2]:bins[-1] + 1] = 1 - erb_filters[-2, bins[-2]:bins[-1] + 1]

        # Drop the columns of the low bins that bypass the filter bank.
        erb_filters = erb_filters[:, erb_subband_1:]
        return torch.from_numpy(np.abs(erb_filters))

    def bm(self, x):
        """Band merging.  x: (B,C,T,F) -> (B,C,T,erb_subband_1 + erb_subband_2)."""
        x_low = x[..., :self.erb_subband_1]
        x_high = self.erb_fc(x[..., self.erb_subband_1:])
        return torch.cat([x_low, x_high], dim=-1)

    def bs(self, x_erb):
        """Band splitting.  x_erb: (B,C,T,F_erb) -> (B,C,T,F)."""
        x_erb_low = x_erb[..., :self.erb_subband_1]
        x_erb_high = self.ierb_fc(x_erb[..., self.erb_subband_1:])
        return torch.cat([x_erb_low, x_erb_high], dim=-1)
class SFE(nn.Module):
    """Subband Feature Extraction.

    Gathers, for every time-frequency position, the ``kernel_size``
    neighbouring values along the frequency axis and stacks them on the
    channel axis.

    Args:
        kernel_size: Number of neighbouring frequency bins per window.
        stride: Step of the window along the frequency axis.
    """

    def __init__(self, kernel_size=3, stride=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.unfold = nn.Unfold(kernel_size=(1, kernel_size), stride=(1, stride), padding=(0, (kernel_size - 1) // 2))

    def forward(self, x):
        """x: (B,C,T,F) -> (B, C*kernel_size, T, F_out).

        ``F_out`` equals ``F`` for ``stride == 1`` with an odd
        ``kernel_size``; otherwise it is whatever width the unfold produced.
        """
        batch, channels, frames, _ = x.size()
        patches = self.unfold(x)  # (B, C*kernel_size, T * F_out)
        # Derive the output width from the unfold result instead of assuming
        # it equals the input width, so strides > 1 and even kernel sizes
        # reshape correctly.
        out_width = patches.size(-1) // max(frames, 1)
        return patches.reshape(batch, channels * self.kernel_size, frames, out_width)
class StreamGTCRN(nn.Module):
    """Minimal model used to reproduce the PyTorch vs. ncnn output mismatch.

    It builds a 3-channel feature (magnitude, real, imaginary) from a complex
    spectrogram, merges the high bins into ERB bands and unfolds the result
    along the frequency axis.
    """

    def __init__(self):
        super().__init__()
        self.erb = ERB(65, 64)
        self.sfe = SFE(3, 1)  # constructed but not used by forward()
        self.unfold = nn.Unfold(kernel_size=(1, 3), stride=(1, 1), padding=(0, (3 - 1) // 2))

    def forward(self, spec):
        """spec: (B,F,T,2) with real/imag in the last axis.

        Returns:
            ``(feat, feat1)`` where ``feat`` is the ERB-merged feature
            (B,3,T,129) and ``feat1`` is its frequency-unfolded version
            (B,9,T,129).
        """
        spec_ref = spec  # (B,F,T,2)
        spec_real = spec[..., 0].permute(0, 2, 1)
        spec_imag = spec[..., 1].permute(0, 2, 1)
        # The 1e-12 keeps the value under the sqrt strictly positive.
        spec_mag = torch.sqrt(spec_real**2 + spec_imag**2 + 1e-12)
        # NOTE: the issue report attributes the ncnn mismatch to this
        # torch.stack; the suggested workaround is
        # torch.cat([t.unsqueeze(1) for t in (...)], dim=1).  Kept as stack
        # here so the script still reproduces the reported behaviour.
        feat = torch.stack([spec_mag, spec_real, spec_imag], dim=1)  # (B,3,T,257)

        feat = self.erb.bm(feat)  # (B,3,T,129)
        B, C, T, F = feat.size()
        feat1 = self.unfold(feat.clone()).reshape(B, C * 3, T, F)  # (B,9,T,129)
        return feat, feat1
import pnnx
if __name__ == '__main__':
    # Seed PyTorch's random number generators so the random input is
    # reproducible between the export run and the comparison run.
    torch.manual_seed(420)

    x = torch.ones(1, 257, 10, 2) + torch.randn(1, 257, 10, 2)
    model = StreamGTCRN()
    model.eval()
    model_jit = torch.jit.trace(model.cpu(), (x), check_trace=True)
    pt_model_path = "StreamGTCRN.pt"
    # Export step: uncomment the next two lines for the first run to produce
    # the TorchScript file and the ncnn param/bin files, then comment them
    # out again for the comparison run.
    # model_jit.save(pt_model_path)
    # opt_model = pnnx.export(model, pt_model_path, x)

    # Reference outputs from the saved TorchScript model.
    model_ = torch.jit.load(pt_model_path).cpu()
    pt_out0, pt_out1 = model_.forward(x)

    # Outputs from the converted ncnn model.
    net = ncnn.Net()
    net.load_param("StreamGTCRN.ncnn.param")
    net.load_model("StreamGTCRN.ncnn.bin")
    input1 = x.numpy()

    ex = net.create_extractor()
    ex.input("in0", ncnn.Mat(input1).clone())
    res, out0 = ex.extract("out0")
    res, out1 = ex.extract("out1")
    out = torch.tensor(np.array(out0))
    out1 = torch.tensor(np.array(out1))

    print(f"pt_out0 shape {pt_out0.size()}, out shape {out.size()},max diff {torch.max(torch.abs(pt_out0 - out))}")
how to reproduce | 复现步骤 | 再現方法
-
先将下面两行解除注释并运行一次（导出模型），然后重新加上注释再运行一次：

model_jit.save(pt_model_path)
opt_model = pnnx.export(model, pt_model_path, x)

输出：pt_out0 shape torch.Size([1, 3, 10, 129]), out shape torch.Size([3, 10, 129]),max diff 6.7004618644714355
问题出在 torch.stack：这个算子的计算逻辑和 PyTorch 不同。可以用 unsqueeze 和 concat（torch.cat）来替换它。