ncnn (Tencent/ncnn) — GitHub issue

gtcrn模型转换问题 (gtcrn model conversion problem)

Open cp917 opened this issue 8 months ago • 2 comments

error log | 日志或报错信息 | ログ

```text
pnnxparam = ./model2.pnnx.param
pnnxbin = ./model2.pnnx.bin
pnnxpy = ./model2_pnnx.py
ncnnparam = ./model2.ncnn.param
ncnnbin = ./model2.ncnn.bin
ncnnpy = ./model2_ncnn.py
fp16 = 1
optlevel = 2
device = cpu
inputshape = [1,257,1,2]f32,[2,1,16,16,33]f32,[2,3,1,1,16]f32,[2,1,33,16]f32
inputshape2 =
customop =
moduleop =
############# pass_level0
inline module = ConvBlock
inline module = Mask
inline module = SFE
inline module = StreamDecoder
inline module = StreamEncoder
inline module = StreamGTConvBlock
inline module = StreamTRA
inline module = modules.convolution.StreamConv2d
inline module = modules.convolution.StreamConvTranspose2d
inline module = ConvBlock
inline module = Mask
inline module = SFE
inline module = StreamDecoder
inline module = StreamEncoder
inline module = StreamGTConvBlock
inline module = StreamTRA
inline module = modules.convolution.StreamConv2d
inline module = modules.convolution.StreamConvTranspose2d
```


############# pass_level1 ############# pass_level2 ############# pass_level3 ############# pass_level4 ############# pass_level5 ############# pass_ncnn fallback batch axis 233 for operand 5 insert_reshape_linear 4 insert_reshape_linear 4 select along batch axis 0 is not supported select along batch axis 0 is not supported slice along batch axis is not supported slice along batch axis is not supported slice along batch axis is not supported slice along batch axis is not supported slice along batch axis is not supported slice along batch axis is not supported slice along batch axis is not supported slice along batch axis is not supported slice along batch axis is not supported slice along batch axis is not supported slice along batch axis is not supported slice along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported slice_copy along batch axis is not supported reshape tensor with batch index 1 is not supported yet! reshape tensor with batch index 1 is not supported yet! reshape tensor with batch index 1 is not supported yet! reshape tensor with batch index 1 is not supported yet! reshape tensor with batch index 1 is not supported yet! reshape tensor with batch index 1 is not supported yet! reshape tensor with batch index 1 is not supported yet! reshape tensor with batch index 1 is not supported yet! 
ignore Crop select_0 param dim=0 ignore Crop select_0 param index=0 ignore Crop select_2 param dim=0 ignore Crop select_2 param index=1

model | 模型 | モデル

  # 1. original model
class ERB(nn.Module):
    """Equivalent Rectangular Bandwidth (ERB) band merging / splitting.

    The lowest ``erb_subband_1`` frequency bins pass through unchanged. The
    remaining ``nfft//2 + 1 - erb_subband_1`` bins are compressed into
    ``erb_subband_2`` bands by a fixed (non-trainable) filter bank in ``bm``
    and expanded back with the transposed filter bank in ``bs``.
    """

    def __init__(self, erb_subband_1, erb_subband_2, nfft=512, high_lim=8000, fs=16000):
        super().__init__()
        erb_filters = self.erb_filter_banks(erb_subband_1, erb_subband_2, nfft, high_lim, fs)
        nfreqs = nfft // 2 + 1
        self.erb_subband_1 = erb_subband_1
        self.erb_fc = nn.Linear(nfreqs - erb_subband_1, erb_subband_2, bias=False)
        self.ierb_fc = nn.Linear(erb_subband_2, nfreqs - erb_subband_1, bias=False)
        # Fixed filter bank: analysis weights are the filters, synthesis weights
        # their transpose; neither is trained.
        self.erb_fc.weight = nn.Parameter(erb_filters, requires_grad=False)
        self.ierb_fc.weight = nn.Parameter(erb_filters.T, requires_grad=False)

    def hz2erb(self, freq_hz):
        """Convert a frequency in Hz to the ERB-rate scale."""
        erb_f = 21.4 * np.log10(0.00437 * freq_hz + 1)
        return erb_f

    def erb2hz(self, erb_f):
        """Inverse of ``hz2erb``: convert an ERB-rate value back to Hz."""
        freq_hz = (10 ** (erb_f / 21.4) - 1) / 0.00437
        return freq_hz

    def erb_filter_banks(self, erb_subband_1, erb_subband_2, nfft=512, high_lim=8000, fs=16000):
        """Build the ERB filter matrix.

        Band edges are equally spaced on the ERB scale between the frequency
        of bin ``erb_subband_1`` and ``high_lim``. Adjacent bands overlap with
        linear rising/falling ramps. Columns below ``erb_subband_1`` are
        dropped because those bins bypass the filter bank.

        Returns:
            torch.Tensor of shape (erb_subband_2, nfft//2 + 1 - erb_subband_1).
        """
        low_lim = erb_subband_1 / nfft * fs
        erb_low = self.hz2erb(low_lim)
        erb_high = self.hz2erb(high_lim)
        erb_points = np.linspace(erb_low, erb_high, erb_subband_2)
        bins = np.round(self.erb2hz(erb_points) / fs * nfft).astype(np.int32)
        erb_filters = np.zeros([erb_subband_2, nfft // 2 + 1], dtype=np.float32)

        # First band: falling ramp only.
        erb_filters[0, bins[0]:bins[1]] = (bins[1] - np.arange(bins[0], bins[1]) + 1e-12) \
            / (bins[1] - bins[0] + 1e-12)
        # Middle bands: rising ramp up to bins[i+1], then falling ramp to bins[i+2].
        for i in range(erb_subband_2 - 2):
            erb_filters[i + 1, bins[i]:bins[i + 1]] = (np.arange(bins[i], bins[i + 1]) - bins[i] + 1e-12) \
                / (bins[i + 1] - bins[i] + 1e-12)
            erb_filters[i + 1, bins[i + 1]:bins[i + 2]] = (bins[i + 2] - np.arange(bins[i + 1], bins[i + 2]) + 1e-12) \
                / (bins[i + 2] - bins[i + 1] + 1e-12)

        # Last band: complement of the previous band's falling ramp.
        erb_filters[-1, bins[-2]:bins[-1] + 1] = 1 - erb_filters[-2, bins[-2]:bins[-1] + 1]

        erb_filters = erb_filters[:, erb_subband_1:]
        return torch.from_numpy(np.abs(erb_filters))

    def bm(self, x):
        """Band merge: x (B,C,T,F) -> (B,C,T,erb_subband_1 + erb_subband_2)."""
        x_low = x[..., :self.erb_subband_1]
        x_high = self.erb_fc(x[..., self.erb_subband_1:])
        return torch.cat([x_low, x_high], dim=-1)

    def bs(self, x_erb):
        """Band split (inverse of ``bm``): x_erb (B,C,T,F_erb) -> (B,C,T,F)."""
        x_erb_low = x_erb[..., :self.erb_subband_1]
        x_erb_high = self.ierb_fc(x_erb[..., self.erb_subband_1:])
        return torch.cat([x_erb_low, x_erb_high], dim=-1)

class SFE(nn.Module):
    """Subband Feature Extraction.

    Stacks every frequency bin with its neighbours along the channel axis by
    unfolding a (1, kernel_size) window over the (T, F) plane.
    """

    def __init__(self, kernel_size=3, stride=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.unfold = nn.Unfold(kernel_size=(1, kernel_size), stride=(1, stride),
                                padding=(0, (kernel_size - 1) // 2))

    def forward(self, x):
        """x: (B,C,T,F) -> (B, C*kernel_size, T, F_out).

        F_out == F for an odd kernel with stride 1, the configuration used in
        this file. The last axis is inferred from the unfolded length rather than
        hard-coded to F. The old hard-coded reshape raised for stride > 1 or
        even kernel sizes.
        """
        xs = self.unfold(x).reshape(x.shape[0], x.shape[1] * self.kernel_size, x.shape[2], -1)
        return xs

class StreamTRA(nn.Module):
    """Temporal Recurrent Attention (streaming form with an explicit GRU state)."""

    def __init__(self, channels):
        super().__init__()
        # GRU hidden size is 2*channels, so the carried state is (1, B, 2*channels).
        self.att_gru = nn.GRU(channels, channels * 2, 1, batch_first=True)
        self.att_fc = nn.Linear(channels * 2, channels)
        self.att_act = nn.Sigmoid()

    def forward(self, x, h_cache):
        """Scale each (channel, frame) of x by a sigmoid attention weight.

        x: (B,C,T,F)
        h_cache: (1,B,2C) -- GRU state carried between calls

        Returns (x * attention, updated h_cache).
        """
        zt = torch.mean(x.pow(2), dim=-1)  # (B,C,T): per-frame energy
        at, h_cache = self.att_gru(zt.transpose(1, 2), h_cache)  # (B,T,2C)
        at = self.att_fc(at).transpose(1, 2)  # (B,C,T)
        at = self.att_act(at)
        At = at[..., None]  # (B,C,T,1)

        return x * At, h_cache

class ConvBlock(nn.Module):
    """Conv2d (or ConvTranspose2d) -> BatchNorm2d -> PReLU (Tanh when ``is_last``)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 groups=1, use_deconv=False, is_last=False):
        super().__init__()
        conv_module = nn.ConvTranspose2d if use_deconv else nn.Conv2d
        self.conv = conv_module(in_channels, out_channels, kernel_size, stride, padding, groups=groups)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.Tanh() if is_last else nn.PReLU()

    def forward(self, x):
        """x: (B,C_in,T,F) -> (B,C_out,T',F')."""
        return self.act(self.bn(self.conv(x)))

class StreamGTConvBlock(nn.Module):
    """Group Temporal Convolution block (streaming variant).

    The input channels are split in half:
    - The first half goes through SFE -> 1x1 conv -> depthwise streaming
      temporal conv -> 1x1 conv -> StreamTRA.
    - The second half is passed through untouched.
    The two halves are then interleaved channel-wise by ``shuffle``.
    """

    def __init__(self, in_channels, hidden_channels, kernel_size, stride, padding,
                 dilation, use_deconv=False):
        super().__init__()
        self.use_deconv = use_deconv
        conv_module = nn.ConvTranspose2d if use_deconv else nn.Conv2d
        # StreamConv2d / StreamConvTranspose2d are not defined in this file (the
        # pnnx log lists them as modules.convolution.*). Their call takes and
        # returns a cache; see forward.
        stream_conv_module = StreamConvTranspose2d if use_deconv else StreamConv2d

        self.sfe = SFE(kernel_size=3, stride=1)

        # SFE(kernel_size=3) triples the channel count of the half-split input.
        self.point_conv1 = conv_module(in_channels // 2 * 3, hidden_channels, 1)
        self.point_bn1 = nn.BatchNorm2d(hidden_channels)
        self.point_act = nn.PReLU()

        self.depth_conv = stream_conv_module(hidden_channels, hidden_channels, kernel_size,
                                             stride=stride, padding=padding,
                                             dilation=dilation, groups=hidden_channels)
        self.depth_bn = nn.BatchNorm2d(hidden_channels)
        self.depth_act = nn.PReLU()

        self.point_conv2 = conv_module(hidden_channels, in_channels // 2, 1)
        self.point_bn2 = nn.BatchNorm2d(in_channels // 2)

        self.tra = StreamTRA(in_channels // 2)

    def shuffle(self, x1, x2):
        """Interleave the channels of two (B,C,T,F) tensors into (B,2C,T,F)."""
        x = torch.stack([x1, x2], dim=1)  # (B,2,C,T,F)
        x = x.transpose(1, 2).contiguous()  # (B,C,2,T,F)
        x = x.view(x.shape[0], -1, x.shape[3], x.shape[4])  # (B,2C,T,F)
        return x

    def forward(self, x, conv_cache, tra_cache):
        """
        x: (B, C, T, F) with C == in_channels
        conv_cache: (B, hidden_channels, (kT-1)*dT, F) -- state of the depthwise
            streaming conv (layout as expected by StreamConv2d; not visible here)
        tra_cache: (1, B, C) -- StreamTRA GRU state (its hidden size is 2*(C//2))

        Returns (output (B,C,T,F), updated conv_cache, updated tra_cache).
        """
        half = x.shape[1] // 2
        x1, x2 = x[:, :half], x[:, half:]

        x1 = self.sfe(x1)
        h1 = self.point_act(self.point_bn1(self.point_conv1(x1)))
        h1, conv_cache = self.depth_conv(h1, conv_cache)
        h1 = self.depth_act(self.depth_bn(h1))
        h1 = self.point_bn2(self.point_conv2(h1))

        h1, tra_cache = self.tra(h1, tra_cache)

        x = self.shuffle(h1, x2)

        return x, conv_cache, tra_cache

class GRNN(nn.Module):
    """Grouped RNN: split features and hidden state in two halves, each run by its own GRU."""

    def __init__(self, input_size, hidden_size, num_layers=1, batch_first=True, bidirectional=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.rnn1 = nn.GRU(input_size // 2, hidden_size // 2, num_layers,
                           batch_first=batch_first, bidirectional=bidirectional)
        self.rnn2 = nn.GRU(input_size // 2, hidden_size // 2, num_layers,
                           batch_first=batch_first, bidirectional=bidirectional)

    def forward(self, x, h=None):
        """
        x: (B, seq_length, input_size)
        h: (num_layers * num_directions, B, hidden_size), or None for a zero state

        Returns (y, h) with y: (B, seq_length, hidden_size * num_directions).

        NOTE(review): a comment on the ncnn issue reports that after ncnn
        conversion the two torch.chunk(..., dim=-1) calls below split along
        the second axis instead of the last; unverified here.
        """
        if h is None:
            num_directions = 2 if self.bidirectional else 1
            # Match x's dtype so non-float32 inputs don't hit a GRU dtype mismatch.
            h = torch.zeros(self.num_layers * num_directions, x.shape[0], self.hidden_size,
                            device=x.device, dtype=x.dtype)
        x1, x2 = torch.chunk(x, chunks=2, dim=-1)
        h1, h2 = torch.chunk(h, chunks=2, dim=-1)
        # GRU requires a contiguous initial hidden state; chunk returns views.
        h1, h2 = h1.contiguous(), h2.contiguous()
        y1, h1 = self.rnn1(x1, h1)
        y2, h2 = self.rnn2(x2, h2)
        y = torch.cat([y1, y2], dim=-1)
        h = torch.cat([h1, h2], dim=-1)
        return y, h

class DPGRNN(nn.Module):
    """Grouped Dual-path RNN.

    Intra path: bidirectional GRNN along frequency, independently per frame.
    Inter path: unidirectional GRNN along time, independently per frequency
    bin, with its hidden state carried between calls in ``inter_cache``.
    Each path is followed by Linear + LayerNorm over (F, C) and a residual add.
    """

    def __init__(self, input_size, width, hidden_size, **kwargs):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.width = width
        self.hidden_size = hidden_size

        self.intra_rnn = GRNN(input_size=input_size, hidden_size=hidden_size // 2, bidirectional=True)
        self.intra_fc = nn.Linear(hidden_size, hidden_size)
        self.intra_ln = nn.LayerNorm((width, hidden_size), eps=1e-8)

        self.inter_rnn = GRNN(input_size=input_size, hidden_size=hidden_size, bidirectional=False)
        self.inter_fc = nn.Linear(hidden_size, hidden_size)
        self.inter_ln = nn.LayerNorm((width, hidden_size), eps=1e-8)

    def forward(self, x, inter_cache):
        """
        x: (B, C, T, F) with F == width and C == hidden_size (needed by the residual adds)
        inter_cache: (1, B*F, hidden_size) -- inter-path GRNN state

        Returns (output (B, C, T, F), updated inter_cache).
        """
        ## Intra RNN (along frequency)
        x = x.permute(0, 2, 3, 1)  # (B,T,F,C)
        intra_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])  # (B*T,F,C)
        intra_x = self.intra_rnn(intra_x)[0]  # (B*T,F,C)
        intra_x = self.intra_fc(intra_x)  # (B*T,F,C)
        intra_x = intra_x.reshape(x.shape[0], -1, self.width, self.hidden_size)  # (B,T,F,C)
        intra_x = self.intra_ln(intra_x)
        intra_out = torch.add(x, intra_x)

        ## Inter RNN (along time)
        x = intra_out.permute(0, 2, 1, 3)  # (B,F,T,C)
        inter_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])  # (B*F,T,C)
        inter_x, inter_cache = self.inter_rnn(inter_x, inter_cache)  # (B*F,T,C)
        inter_x = self.inter_fc(inter_x)  # (B*F,T,C)
        inter_x = inter_x.reshape(x.shape[0], self.width, -1, self.hidden_size)  # (B,F,T,C)
        inter_x = inter_x.permute(0, 2, 1, 3)  # (B,T,F,C)
        inter_x = self.inter_ln(inter_x)
        inter_out = torch.add(intra_out, inter_x)

        dual_out = inter_out.permute(0, 3, 1, 2)  # (B,C,T,F)

        return dual_out, inter_cache

class StreamEncoder(nn.Module):
    """Streaming encoder: two strided ConvBlocks followed by three dilated StreamGTConvBlocks."""

    def __init__(self):
        super().__init__()
        self.en_convs = nn.ModuleList([
            ConvBlock(3 * 3, 16, (1, 5), stride=(1, 2), padding=(0, 2), use_deconv=False, is_last=False),
            ConvBlock(16, 16, (1, 5), stride=(1, 2), padding=(0, 2), groups=2, use_deconv=False, is_last=False),
            StreamGTConvBlock(16, 16, (3, 3), stride=(1, 1), padding=(0, 1), dilation=(1, 1), use_deconv=False),
            StreamGTConvBlock(16, 16, (3, 3), stride=(1, 1), padding=(0, 1), dilation=(2, 1), use_deconv=False),
            StreamGTConvBlock(16, 16, (3, 3), stride=(1, 1), padding=(0, 1), dilation=(5, 1), use_deconv=False),
        ])

    def forward(self, x, conv_cache, tra_cache):
        """
        x: (B,C,T,F)
        conv_cache: (B,C,(kT-1)*8,F). Time slices [:2], [2:6], [6:16] hold the
            state of the dilation-1, -2 and -5 blocks ((kT-1)*d frames each).
        tra_cache: (3,1,B,C), one StreamTRA state per StreamGTConvBlock.

        Returns (x, en_outs, conv_cache, tra_cache). en_outs holds the 5 block
        outputs for the decoder skips. The caches are written in place through
        the slice assignments below and also returned.
        """
        en_outs = []
        for i in range(2):
            x = self.en_convs[i](x)
            en_outs.append(x)

        x, conv_cache[:, :, :2, :], tra_cache[0] = self.en_convs[2](x, conv_cache[:, :, :2, :], tra_cache[0])
        en_outs.append(x)
        x, conv_cache[:, :, 2:6, :], tra_cache[1] = self.en_convs[3](x, conv_cache[:, :, 2:6, :], tra_cache[1])
        en_outs.append(x)
        x, conv_cache[:, :, 6:16, :], tra_cache[2] = self.en_convs[4](x, conv_cache[:, :, 6:16, :], tra_cache[2])
        en_outs.append(x)

        return x, en_outs, conv_cache, tra_cache

class StreamDecoder(nn.Module):
    """Streaming decoder: mirror of StreamEncoder using transposed convolutions and skip additions."""

    def __init__(self):
        super().__init__()
        self.de_convs = nn.ModuleList([
            StreamGTConvBlock(16, 16, (3, 3), stride=(1, 1), padding=(0, 1), dilation=(5, 1), use_deconv=True),
            StreamGTConvBlock(16, 16, (3, 3), stride=(1, 1), padding=(0, 1), dilation=(2, 1), use_deconv=True),
            StreamGTConvBlock(16, 16, (3, 3), stride=(1, 1), padding=(0, 1), dilation=(1, 1), use_deconv=True),
            ConvBlock(16, 16, (1, 5), stride=(1, 2), padding=(0, 2), groups=2, use_deconv=True, is_last=False),
            ConvBlock(16, 2, (1, 5), stride=(1, 2), padding=(0, 2), use_deconv=True, is_last=True),
        ])

    def forward(self, x, en_outs, conv_cache, tra_cache):
        """
        x: (B,C,T,F)
        en_outs: the 5 encoder block outputs; en_outs[4 - k] is added to the
            input of decoder block k.
        conv_cache: (B,C,(kT-1)*8,F). Slices [6:16], [2:6], [:2] belong to the
            dilation-5, -2 and -1 blocks, in decoding order.
        tra_cache: (3,1,B,C)

        Returns (x, conv_cache, tra_cache). The caches are written in place
        through the slice assignments below and also returned.
        """
        x, conv_cache[:, :, 6:16, :], tra_cache[0] = self.de_convs[0](x + en_outs[4], conv_cache[:, :, 6:16, :], tra_cache[0])
        x, conv_cache[:, :, 2:6, :], tra_cache[1] = self.de_convs[1](x + en_outs[3], conv_cache[:, :, 2:6, :], tra_cache[1])
        x, conv_cache[:, :, :2, :], tra_cache[2] = self.de_convs[2](x + en_outs[2], conv_cache[:, :, :2, :], tra_cache[2])

        for i in range(3, 5):
            x = self.de_convs[i](x + en_outs[4 - i])
        return x, conv_cache, tra_cache

class Mask(nn.Module):
    """Complex Ratio Mask: complex multiplication of a spectrum by a mask."""

    def __init__(self):
        super().__init__()

    def forward(self, mask, spec):
        """Return spec * mask as complex numbers.

        mask, spec: (B,2,T,F), with channel 0 the real part and channel 1 the
        imaginary part. Returns (B,2,T,F) in the same layout.
        """
        s_real = spec[:, 0] * mask[:, 0] - spec[:, 1] * mask[:, 1]
        s_imag = spec[:, 1] * mask[:, 0] + spec[:, 0] * mask[:, 1]
        s = torch.stack([s_real, s_imag], dim=1)  # (B,2,T,F)
        return s

class StreamGTCRN(nn.Module):
    """Streaming GTCRN: enhances STFT frames while carrying state in explicit caches."""

    def __init__(self):
        super().__init__()
        self.erb = ERB(65, 64)
        self.sfe = SFE(3, 1)

        self.encoder = StreamEncoder()

        self.dpgrnn1 = DPGRNN(16, 33, 16)
        self.dpgrnn2 = DPGRNN(16, 33, 16)

        self.decoder = StreamDecoder()

        self.mask = Mask()

    def forward(self, spec, conv_cache, tra_cache, inter_cache):
        """
        spec: (B, F, T, 2) = (1, 257, 1, 2), real/imag in the last axis
        conv_cache: [en_cache, de_cache], (2, B, C, 8*(kT-1), F) = (2, 1, 16, 16, 33)
        tra_cache: [en_cache, de_cache], (2, 3, 1, B, C) = (2, 3, 1, 1, 16)
        inter_cache: [cache1, cache2], (2, 1, B*F, C) = (2, 1, 33, 16)

        Returns (spec_enh, conv_cache, tra_cache, inter_cache). spec_enh has the
        same (B, F, T, 2) layout as spec. The cache tensors passed in are
        updated in place via indexed assignment and also returned.

        NOTE(review): pnnx's ncnn pass reports "select/slice/slice_copy along
        batch axis is not supported" for this model. Plausible causes (unverified):
        - the caches' leading size-2 stacking axis, indexed with [0]/[1] below;
        - the in-place slice writes in StreamEncoder/StreamDecoder.
        """
        spec_ref = spec  # (B,F,T,2)

        spec_real = spec[..., 0].permute(0, 2, 1)  # (B,T,F)
        spec_imag = spec[..., 1].permute(0, 2, 1)  # (B,T,F)
        spec_mag = torch.sqrt(spec_real ** 2 + spec_imag ** 2 + 1e-12)
        feat = torch.stack([spec_mag, spec_real, spec_imag], dim=1)  # (B,3,T,257)

        feat = self.erb.bm(feat)  # (B,3,T,129)
        feat = self.sfe(feat)  # (B,9,T,129)

        feat, en_outs, conv_cache[0], tra_cache[0] = self.encoder(feat, conv_cache[0], tra_cache[0])

        feat, inter_cache[0] = self.dpgrnn1(feat, inter_cache[0])  # (B,16,T,33)
        feat, inter_cache[1] = self.dpgrnn2(feat, inter_cache[1])  # (B,16,T,33)

        m_feat, conv_cache[1], tra_cache[1] = self.decoder(feat, en_outs, conv_cache[1], tra_cache[1])

        m = self.erb.bs(m_feat)  # (B,2,T,257)

        spec_enh = self.mask(m, spec_ref.permute(0, 3, 2, 1))  # (B,2,T,F)
        spec_enh = spec_enh.permute(0, 3, 2, 1)  # (B,F,T,2)

        return spec_enh, conv_cache, tra_cache, inter_cache

if __name__ == "__main__":
    import os
    import time

    import soundfile as sf
    from tqdm import tqdm

    from gtcrn import GTCRN
    from modules.convert import convert_to_stream

    device = torch.device("cpu")

    model = GTCRN().to(device).eval()
    model.load_state_dict(torch.load('onnx_models/model_trained_on_dns3.tar', map_location=device)['model'])
    stream_model = StreamGTCRN().to(device).eval()
    convert_to_stream(stream_model, model)

    # ------------------------------------------------------------------
    # Streaming conversion
    # ------------------------------------------------------------------
    # Offline inference with the non-streaming model. Its output `y` is the
    # reference the ONNX streaming result is compared against at the end.
    window = torch.hann_window(512).pow(0.5)
    x = torch.from_numpy(sf.read('test_wavs/mix.wav', dtype='float32')[0])
    # (1, 257, T, 2): real/imag stacked in the last axis, as the models expect.
    # view_as_real(return_complex=True) gives the same layout as the deprecated
    # return_complex=False.
    x = torch.view_as_real(torch.stft(x, 512, 256, 512, window, return_complex=True))[None]
    with torch.no_grad():
        y = model(x)
    # Recent PyTorch releases require a complex spectrogram for istft.
    y = torch.istft(torch.view_as_complex(y.contiguous()), 512, 256, 512, window).detach().cpu().numpy()
    sf.write('test_wavs/enh.wav', y.squeeze(), 16000)

    # Zero-initialised streaming caches; shapes must match StreamGTCRN.forward.
    conv_cache = torch.zeros(2, 1, 16, 16, 33).to(device)
    tra_cache = torch.zeros(2, 3, 1, 1, 16).to(device)
    inter_cache = torch.zeros(2, 1, 33, 16).to(device)

    # Optional PyTorch streaming check (disabled):
    # ys = []
    # times = []
    # for i in tqdm(range(x.shape[2])):
    #     xi = x[:, :, i:i+1]
    #     tic = time.perf_counter()
    #     with torch.no_grad():
    #         yi, conv_cache, tra_cache, inter_cache = stream_model(xi, conv_cache, tra_cache, inter_cache)
    #     toc = time.perf_counter()
    #     times.append((toc - tic) * 1000)
    #     ys.append(yi)
    # ys = torch.cat(ys, dim=2)
    # ys = torch.istft(torch.view_as_complex(ys.contiguous()), 512, 256, 512, window).detach().cpu().numpy()
    # sf.write('test_wavs/enh_stream.wav', ys.squeeze(), 16000)
    # print(">>> inference time: mean: {:.1f}ms, max: {:.1f}ms, min: {:.1f}ms".format(sum(times)/len(times), max(times), min(times)))
    # print(">>> Streaming error:", np.abs(y-ys).max())

    # ------------------------------------------------------------------
    # ONNX conversion
    # ------------------------------------------------------------------
    import onnx
    import onnxruntime
    from onnxsim import simplify
    from librosa import istft

    # Relative to the working directory, like the checkpoint path above
    # (previously a user-specific absolute Windows path).
    onnx_file = 'onnx_models/gtcrn.onnx'
    simple_file = onnx_file.split('.onnx')[0] + '_simple.onnx'

    if not os.path.exists(onnx_file):
        os.makedirs(os.path.dirname(onnx_file), exist_ok=True)
        dummy_mix = torch.randn(1, 257, 1, 2, device=device)
        torch.onnx.export(stream_model,
                          (dummy_mix, conv_cache, tra_cache, inter_cache),
                          onnx_file,
                          input_names=['mix', 'conv_cache', 'tra_cache', 'inter_cache'],
                          output_names=['enh', 'conv_cache_out', 'tra_cache_out', 'inter_cache_out'],
                          opset_version=11,
                          verbose=False)
        onnx.checker.check_model(onnx.load(onnx_file))

    if not os.path.exists(simple_file):
        # Load here instead of reusing a model from the export branch. That branch
        # is skipped when gtcrn.onnx already exists, which used to raise NameError.
        onnx_model = onnx.load(onnx_file)
        model_simp, check = simplify(onnx_model, perform_optimization=False)
        if not check:
            raise RuntimeError("Simplified ONNX model could not be validated")
        onnx.save(model_simp, simple_file)

    # Run the simplified ONNX model frame by frame, feeding the caches back in.
    # session = onnxruntime.InferenceSession(onnx_file, None, providers=['CPUExecutionProvider'])
    session = onnxruntime.InferenceSession(simple_file, None, providers=['CPUExecutionProvider'])
    conv_cache = np.zeros([2, 1, 16, 16, 33], dtype="float32")
    tra_cache = np.zeros([2, 3, 1, 1, 16], dtype="float32")
    inter_cache = np.zeros([2, 1, 33, 16], dtype="float32")

    T_list = []
    outputs = []

    inputs = x.numpy()
    for i in tqdm(range(inputs.shape[-2])):
        tic = time.perf_counter()

        out_i, conv_cache, tra_cache, inter_cache = session.run(
            [], {'mix': inputs[..., i:i+1, :],
                 'conv_cache': conv_cache,
                 'tra_cache': tra_cache,
                 'inter_cache': inter_cache})

        toc = time.perf_counter()
        T_list.append(toc - tic)
        outputs.append(out_i)

    outputs = np.concatenate(outputs, axis=2)
    # Use the same periodic sqrt-Hann window as torch.stft/istft above.
    # np.hanning is the symmetric window and would add a small mismatch to the
    # reported ONNX error.
    enhanced = istft(outputs[..., 0] + 1j * outputs[..., 1], n_fft=512, hop_length=256,
                     win_length=512, window=window.numpy())
    sf.write('test_wavs/enh_onnx.wav', enhanced.squeeze(), 16000)

    print(">>> Onnx error:", np.abs(y - enhanced).max())
    print(">>> inference time: mean: {:.1f}ms, max: {:.1f}ms, min: {:.1f}ms".format(1e3*np.mean(T_list), 1e3*np.max(T_list), 1e3*np.min(T_list)))
    # RTF = processing time per frame / hop duration (256 samples @ 16 kHz = 16 ms).
    print(">>> RTF:", 1e3*np.mean(T_list) / 16)

how to reproduce | 复现步骤 | 再現方法

1. `pnnx gtcrn_simple.onnx inputshape=[1,257,1,2],[2,1,16,16,33],[2,3,1,1,16],[2,1,33,16]`

cp917 avatar May 21 '25 03:05 cp917

StreamGTCRN 这个模型在转换时出现问题。(The problem occurs when converting the StreamGTCRN model.)

cp917 avatar May 21 '25 03:05 cp917

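我今天转换时也遇到了问题：问题出在中间的 DPGRNN 层，转换后 chunk 操作作用到了第二个维度，而不是最后一个维度：

```python
x1, x2 = torch.chunk(x, chunks=2, dim=-1)
h1, h2 = torch.chunk(h, chunks=2, dim=-1)
```

手动修改 param 文件后即可运行。类似问题可参考：https://github.com/Tencent/ncnn/issues/5941

(I hit the same problem today. It comes from the DPGRNN layers in the middle of the network. After conversion, the `torch.chunk` calls above split along the second dimension instead of the last one. Editing the generated `.param` file by hand makes the model run. See https://github.com/Tencent/ncnn/issues/5941 for a similar issue.)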

jiayukk avatar Oct 11 '25 06:10 jiayukk