
Error message after using a fine-tuned ASR model

Open GUUser91 opened this issue 1 year ago • 4 comments

I get this error message after using a fine-tuned ASR model:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[14], line 5
      4 try:
----> 5     model[key].load_state_dict(params[key])
      6 except:

File StyleTTS2/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict(self, state_dict, strict, assign)
   2188 if len(error_msgs) > 0:
-> 2189     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2190                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2191 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for ASRCNN:
	Missing key(s) in state_dict: "to_mfcc.dct_mat", "init_cnn.conv.weight", "init_cnn.conv.bias", "cnns.0.0.blocks.0.0.conv.weight", "cnns.0.0.blocks.0.0.conv.bias", "cnns.0.0.blocks.0.2.weight", "cnns.0.0.blocks.0.2.bias", "cnns.0.0.blocks.0.4.conv.weight", "cnns.0.0.blocks.0.4.conv.bias", "cnns.0.0.blocks.1.0.conv.weight", "cnns.0.0.blocks.1.0.conv.bias", "cnns.0.0.blocks.1.2.weight", "cnns.0.0.blocks.1.2.bias", "cnns.0.0.blocks.1.4.conv.weight", "cnns.0.0.blocks.1.4.conv.bias", "cnns.0.0.blocks.2.0.conv.weight", "cnns.0.0.blocks.2.0.conv.bias", "cnns.0.0.blocks.2.2.weight", "cnns.0.0.blocks.2.2.bias", "cnns.0.0.blocks.2.4.conv.weight", "cnns.0.0.blocks.2.4.conv.bias", "cnns.0.1.weight", "cnns.0.1.bias", "cnns.1.0.blocks.0.0.conv.weight", "cnns.1.0.blocks.0.0.conv.bias", "cnns.1.0.blocks.0.2.weight", "cnns.1.0.blocks.0.2.bias", "cnns.1.0.blocks.0.4.conv.weight", "cnns.1.0.blocks.0.4.conv.bias", "cnns.1.0.blocks.1.0.conv.weight", "cnns.1.0.blocks.1.0.conv.bias", "cnns.1.0.blocks.1.2.weight", "cnns.1.0.blocks.1.2.bias", "cnns.1.0.blocks.1.4.conv.weight", "cnns.1.0.blocks.1.4.conv.bias", "cnns.1.0.blocks.2.0.conv.weight", "cnns.1.0.blocks.2.0.conv.bias", "cnns.1.0.blocks.2.2.weight", "cnns.1.0.blocks.2.2.bias", "cnns.1.0.blocks.2.4.conv.weight", "cnns.1.0.blocks.2.4.conv.bias", "cnns.1.1.weight", "cnns.1.1.bias", "cnns.2.0.blocks.0.0.conv.weight", "cnns.2.0.blocks.0.0.conv.bias", "cnns.2.0.blocks.0.2.weight", "cnns.2.0.blocks.0.2.bias", "cnns.2.0.blocks.0.4.conv.weight", "cnns.2.0.blocks.0.4.conv.bias", "cnns.2.0.blocks.1.0.conv.weight", "cnns.2.0.blocks.1.0.conv.bias", "cnns.2.0.blocks.1.2.weight", "cnns.2.0.blocks.1.2.bias", "cnns.2.0.blocks.1.4.conv.weight", "cnns.2.0.blocks.1.4.conv.bias", "cnns.2.0.blocks.2.0.conv.weight", "cnns.2.0.blocks.2.0.conv.bias", "cnns.2.0.blocks.2.2.weight", "cnns.2.0.blocks.2.2.bias", "cnns.2.0.blocks.2.4.conv.weight", "cnns.2.0.blocks.2.4.conv.bias", "cnns.2.1.weight", "cnns.2.1.bias", "cnns.3.0.blocks.0.0.conv.weight", 
"cnns.3.0.blocks.0.0.conv.bias", "cnns.3.0.blocks.0.2.weight", "cnns.3.0.blocks.0.2.bias", "cnns.3.0.blocks.0.4.conv.weight", "cnns.3.0.blocks.0.4.conv.bias", "cnns.3.0.blocks.1.0.conv.weight", "cnns.3.0.blocks.1.0.conv.bias", "cnns.3.0.blocks.1.2.weight", "cnns.3.0.blocks.1.2.bias", "cnns.3.0.blocks.1.4.conv.weight", "cnns.3.0.blocks.1.4.conv.bias", "cnns.3.0.blocks.2.0.conv.weight", "cnns.3.0.blocks.2.0.conv.bias", "cnns.3.0.blocks.2.2.weight", "cnns.3.0.blocks.2.2.bias", "cnns.3.0.blocks.2.4.conv.weight", "cnns.3.0.blocks.2.4.conv.bias", "cnns.3.1.weight", "cnns.3.1.bias", "cnns.4.0.blocks.0.0.conv.weight", "cnns.4.0.blocks.0.0.conv.bias", "cnns.4.0.blocks.0.2.weight", "cnns.4.0.blocks.0.2.bias", "cnns.4.0.blocks.0.4.conv.weight", "cnns.4.0.blocks.0.4.conv.bias", "cnns.4.0.blocks.1.0.conv.weight", "cnns.4.0.blocks.1.0.conv.bias", "cnns.4.0.blocks.1.2.weight", "cnns.4.0.blocks.1.2.bias", "cnns.4.0.blocks.1.4.conv.weight", "cnns.4.0.blocks.1.4.conv.bias", "cnns.4.0.blocks.2.0.conv.weight", "cnns.4.0.blocks.2.0.conv.bias", "cnns.4.0.blocks.2.2.weight", "cnns.4.0.blocks.2.2.bias", "cnns.4.0.blocks.2.4.conv.weight", "cnns.4.0.blocks.2.4.conv.bias", "cnns.4.1.weight", "cnns.4.1.bias", "cnns.5.0.blocks.0.0.conv.weight", "cnns.5.0.blocks.0.0.conv.bias", "cnns.5.0.blocks.0.2.weight", "cnns.5.0.blocks.0.2.bias", "cnns.5.0.blocks.0.4.conv.weight", "cnns.5.0.blocks.0.4.conv.bias", "cnns.5.0.blocks.1.0.conv.weight", "cnns.5.0.blocks.1.0.conv.bias", "cnns.5.0.blocks.1.2.weight", "cnns.5.0.blocks.1.2.bias", "cnns.5.0.blocks.1.4.conv.weight", "cnns.5.0.blocks.1.4.conv.bias", "cnns.5.0.blocks.2.0.conv.weight", "cnns.5.0.blocks.2.0.conv.bias", "cnns.5.0.blocks.2.2.weight", "cnns.5.0.blocks.2.2.bias", "cnns.5.0.blocks.2.4.conv.weight", "cnns.5.0.blocks.2.4.conv.bias", "cnns.5.1.weight", "cnns.5.1.bias", "projection.conv.weight", "projection.conv.bias", "ctc_linear.0.linear_layer.weight", "ctc_linear.0.linear_layer.bias", "ctc_linear.2.linear_layer.weight", 
"ctc_linear.2.linear_layer.bias", "asr_s2s.embedding.weight", "asr_s2s.project_to_n_symbols.weight", "asr_s2s.project_to_n_symbols.bias", "asr_s2s.attention_layer.query_layer.linear_layer.weight", "asr_s2s.attention_layer.memory_layer.linear_layer.weight", "asr_s2s.attention_layer.v.linear_layer.weight", "asr_s2s.attention_layer.location_layer.location_conv.conv.weight", "asr_s2s.attention_layer.location_layer.location_dense.linear_layer.weight", "asr_s2s.decoder_rnn.weight_ih", "asr_s2s.decoder_rnn.weight_hh", "asr_s2s.decoder_rnn.bias_ih", "asr_s2s.decoder_rnn.bias_hh", "asr_s2s.project_to_hidden.0.linear_layer.weight", "asr_s2s.project_to_hidden.0.linear_layer.bias". 
	Unexpected key(s) in state_dict: "module.to_mfcc.dct_mat", "module.init_cnn.conv.weight", "module.init_cnn.conv.bias", "module.cnns.0.0.blocks.0.0.conv.weight", "module.cnns.0.0.blocks.0.0.conv.bias", "module.cnns.0.0.blocks.0.2.weight", "module.cnns.0.0.blocks.0.2.bias", "module.cnns.0.0.blocks.0.4.conv.weight", "module.cnns.0.0.blocks.0.4.conv.bias", "module.cnns.0.0.blocks.1.0.conv.weight", "module.cnns.0.0.blocks.1.0.conv.bias", "module.cnns.0.0.blocks.1.2.weight", "module.cnns.0.0.blocks.1.2.bias", "module.cnns.0.0.blocks.1.4.conv.weight", "module.cnns.0.0.blocks.1.4.conv.bias", "module.cnns.0.0.blocks.2.0.conv.weight", "module.cnns.0.0.blocks.2.0.conv.bias", "module.cnns.0.0.blocks.2.2.weight", "module.cnns.0.0.blocks.2.2.bias", "module.cnns.0.0.blocks.2.4.conv.weight", "module.cnns.0.0.blocks.2.4.conv.bias", "module.cnns.0.1.weight", "module.cnns.0.1.bias", "module.cnns.1.0.blocks.0.0.conv.weight", "module.cnns.1.0.blocks.0.0.conv.bias", "module.cnns.1.0.blocks.0.2.weight", "module.cnns.1.0.blocks.0.2.bias", "module.cnns.1.0.blocks.0.4.conv.weight", "module.cnns.1.0.blocks.0.4.conv.bias", "module.cnns.1.0.blocks.1.0.conv.weight", "module.cnns.1.0.blocks.1.0.conv.bias", "module.cnns.1.0.blocks.1.2.weight", "module.cnns.1.0.blocks.1.2.bias", "module.cnns.1.0.blocks.1.4.conv.weight", "module.cnns.1.0.blocks.1.4.conv.bias", "module.cnns.1.0.blocks.2.0.conv.weight", "module.cnns.1.0.blocks.2.0.conv.bias", "module.cnns.1.0.blocks.2.2.weight", "module.cnns.1.0.blocks.2.2.bias", "module.cnns.1.0.blocks.2.4.conv.weight", "module.cnns.1.0.blocks.2.4.conv.bias", "module.cnns.1.1.weight", "module.cnns.1.1.bias", "module.cnns.2.0.blocks.0.0.conv.weight", "module.cnns.2.0.blocks.0.0.conv.bias", "module.cnns.2.0.blocks.0.2.weight", "module.cnns.2.0.blocks.0.2.bias", "module.cnns.2.0.blocks.0.4.conv.weight", "module.cnns.2.0.blocks.0.4.conv.bias", "module.cnns.2.0.blocks.1.0.conv.weight", "module.cnns.2.0.blocks.1.0.conv.bias", "module.cnns.2.0.blocks.1.2.weight", 
"module.cnns.2.0.blocks.1.2.bias", "module.cnns.2.0.blocks.1.4.conv.weight", "module.cnns.2.0.blocks.1.4.conv.bias", "module.cnns.2.0.blocks.2.0.conv.weight", "module.cnns.2.0.blocks.2.0.conv.bias", "module.cnns.2.0.blocks.2.2.weight", "module.cnns.2.0.blocks.2.2.bias", "module.cnns.2.0.blocks.2.4.conv.weight", "module.cnns.2.0.blocks.2.4.conv.bias", "module.cnns.2.1.weight", "module.cnns.2.1.bias", "module.cnns.3.0.blocks.0.0.conv.weight", "module.cnns.3.0.blocks.0.0.conv.bias", "module.cnns.3.0.blocks.0.2.weight", "module.cnns.3.0.blocks.0.2.bias", "module.cnns.3.0.blocks.0.4.conv.weight", "module.cnns.3.0.blocks.0.4.conv.bias", "module.cnns.3.0.blocks.1.0.conv.weight", "module.cnns.3.0.blocks.1.0.conv.bias", "module.cnns.3.0.blocks.1.2.weight", "module.cnns.3.0.blocks.1.2.bias", "module.cnns.3.0.blocks.1.4.conv.weight", "module.cnns.3.0.blocks.1.4.conv.bias", "module.cnns.3.0.blocks.2.0.conv.weight", "module.cnns.3.0.blocks.2.0.conv.bias", "module.cnns.3.0.blocks.2.2.weight", "module.cnns.3.0.blocks.2.2.bias", "module.cnns.3.0.blocks.2.4.conv.weight", "module.cnns.3.0.blocks.2.4.conv.bias", "module.cnns.3.1.weight", "module.cnns.3.1.bias", "module.cnns.4.0.blocks.0.0.conv.weight", "module.cnns.4.0.blocks.0.0.conv.bias", "module.cnns.4.0.blocks.0.2.weight", "module.cnns.4.0.blocks.0.2.bias", "module.cnns.4.0.blocks.0.4.conv.weight", "module.cnns.4.0.blocks.0.4.conv.bias", "module.cnns.4.0.blocks.1.0.conv.weight", "module.cnns.4.0.blocks.1.0.conv.bias", "module.cnns.4.0.blocks.1.2.weight", "module.cnns.4.0.blocks.1.2.bias", "module.cnns.4.0.blocks.1.4.conv.weight", "module.cnns.4.0.blocks.1.4.conv.bias", "module.cnns.4.0.blocks.2.0.conv.weight", "module.cnns.4.0.blocks.2.0.conv.bias", "module.cnns.4.0.blocks.2.2.weight", "module.cnns.4.0.blocks.2.2.bias", "module.cnns.4.0.blocks.2.4.conv.weight", "module.cnns.4.0.blocks.2.4.conv.bias", "module.cnns.4.1.weight", "module.cnns.4.1.bias", "module.cnns.5.0.blocks.0.0.conv.weight", 
"module.cnns.5.0.blocks.0.0.conv.bias", "module.cnns.5.0.blocks.0.2.weight", "module.cnns.5.0.blocks.0.2.bias", "module.cnns.5.0.blocks.0.4.conv.weight", "module.cnns.5.0.blocks.0.4.conv.bias", "module.cnns.5.0.blocks.1.0.conv.weight", "module.cnns.5.0.blocks.1.0.conv.bias", "module.cnns.5.0.blocks.1.2.weight", "module.cnns.5.0.blocks.1.2.bias", "module.cnns.5.0.blocks.1.4.conv.weight", "module.cnns.5.0.blocks.1.4.conv.bias", "module.cnns.5.0.blocks.2.0.conv.weight", "module.cnns.5.0.blocks.2.0.conv.bias", "module.cnns.5.0.blocks.2.2.weight", "module.cnns.5.0.blocks.2.2.bias", "module.cnns.5.0.blocks.2.4.conv.weight", "module.cnns.5.0.blocks.2.4.conv.bias", "module.cnns.5.1.weight", "module.cnns.5.1.bias", "module.projection.conv.weight", "module.projection.conv.bias", "module.ctc_linear.0.linear_layer.weight", "module.ctc_linear.0.linear_layer.bias", "module.ctc_linear.2.linear_layer.weight", "module.ctc_linear.2.linear_layer.bias", "module.asr_s2s.embedding.weight", "module.asr_s2s.project_to_n_symbols.weight", "module.asr_s2s.project_to_n_symbols.bias", "module.asr_s2s.attention_layer.query_layer.linear_layer.weight", "module.asr_s2s.attention_layer.memory_layer.linear_layer.weight", "module.asr_s2s.attention_layer.v.linear_layer.weight", "module.asr_s2s.attention_layer.location_layer.location_conv.conv.weight", "module.asr_s2s.attention_layer.location_layer.location_dense.linear_layer.weight", "module.asr_s2s.decoder_rnn.weight_ih", "module.asr_s2s.decoder_rnn.weight_hh", "module.asr_s2s.decoder_rnn.bias_ih", "module.asr_s2s.decoder_rnn.bias_hh", "module.asr_s2s.project_to_hidden.0.linear_layer.weight", "module.asr_s2s.project_to_hidden.0.linear_layer.bias". 

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[14], line 14
     12                 new_state_dict[name] = v
     13             # load params
---> 14             model[key].load_state_dict(new_state_dict, strict=False)
     15 #             except:
     16 #                 _load(params[key], model[key])
     17 _ = [model[key].eval() for key in model]

File StyleTTS2/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict(self, state_dict, strict, assign)
   2184         error_msgs.insert(
   2185             0, 'Missing key(s) in state_dict: {}. '.format(
   2186                 ', '.join(f'"{k}"' for k in missing_keys)))
   2188 if len(error_msgs) > 0:
-> 2189     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2190                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2191 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for ASRCNN:
	size mismatch for ctc_linear.2.linear_layer.weight: copying a param with shape torch.Size([178, 256]) from checkpoint, the shape in current model is torch.Size([80, 256]).
	size mismatch for ctc_linear.2.linear_layer.bias: copying a param with shape torch.Size([178]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for asr_s2s.embedding.weight: copying a param with shape torch.Size([178, 512]) from checkpoint, the shape in current model is torch.Size([80, 256]).
	size mismatch for asr_s2s.project_to_n_symbols.weight: copying a param with shape torch.Size([178, 128]) from checkpoint, the shape in current model is torch.Size([80, 128]).
	size mismatch for asr_s2s.project_to_n_symbols.bias: copying a param with shape torch.Size([178]) from checkpoint, the shape in current model is torch.Size([80]).
	size mismatch for asr_s2s.decoder_rnn.weight_ih: copying a param with shape torch.Size([512, 640]) from checkpoint, the shape in current model is torch.Size([512, 384]).

GUUser91 avatar Jun 12 '24 12:06 GUUser91
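Note for readers: in the first traceback, every "missing" key and its "unexpected" counterpart differ only by a leading "module." prefix, which is what checkpoints saved from an nn.DataParallel-wrapped model look like. The notebook's second cell already strips that prefix, which is why only the size-mismatch errors remain. A minimal hedged sketch of such a loader (the helper name is mine, not from the repo; it strips the prefix and skips any tensors whose shapes don't match the current model):

```python
import torch
import torch.nn as nn

def load_compatible_state_dict(model: nn.Module, checkpoint_state: dict) -> list:
    """Load a checkpoint into `model`, tolerating a DataParallel prefix
    and shape mismatches. Returns the keys skipped due to shape mismatch."""
    # Remove the "module." prefix that nn.DataParallel adds when saving.
    cleaned = {k.removeprefix("module."): v for k, v in checkpoint_state.items()}

    model_state = model.state_dict()
    filtered, skipped = {}, []
    for k, v in cleaned.items():
        if k in model_state and model_state[k].shape == v.shape:
            filtered[k] = v
        else:
            skipped.append(k)
    # strict=False because we deliberately withhold mismatched tensors.
    model.load_state_dict(filtered, strict=False)
    return skipped
```

Skipping mismatched tensors only papers over the problem, though: the layers reported below (ctc_linear, asr_s2s) would keep their random initialization, so aligning the config with the checkpoint (as discussed further down) is the real fix.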

@GUUser91 Were you able to solve this? I have the same error.

MARafey avatar Dec 29 '24 10:12 MARafey

@MARafey No.

GUUser91 avatar Dec 29 '24 18:12 GUUser91

@GUUser91 I was able to solve the issue. It seems the parameters in my config file for fine-tuning didn't match the ones in the Utils folder.

MARafey avatar Dec 30 '24 16:12 MARafey

I ran into the same issue, so rather than training an ASR from scratch, I fine-tuned from the checkpoint in the Utils folder and made sure I used the same configuration as the pre-trained networks. There's a mismatch between the config in the AuxiliaryASR repo and the one in StyleTTS2/Utils/ASR.

DrBrule avatar Jan 17 '25 22:01 DrBrule
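For later readers: the size mismatches in the second traceback all trace back to the symbol-table size, 178 in the checkpoint versus 80 in the model built from the config. One way to check which size a given checkpoint expects is to read the output dimension of its final CTC projection. A hedged sketch (key name taken from the traceback above; the config field that has to match it is the ASR config's token count, e.g. n_token in AuxiliaryASR, though the exact name may vary by repo version):

```python
import torch

def infer_n_token(state: dict) -> int:
    """Infer the symbol-table size a checkpoint was trained with from the
    output dimension of the final CTC linear layer."""
    key = "ctc_linear.2.linear_layer.weight"
    # Accept checkpoints saved with or without the DataParallel prefix.
    for k in (key, "module." + key):
        if k in state:
            return state[k].shape[0]
    raise KeyError(key)
```

If this returns 178 but your config builds an 80-symbol model, the two were made for different symbol sets, and either the config or the checkpoint has to change.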