FastSpeech2

How to export FastSpeech2 to ONNX?

Open · youngstu opened this issue 4 years ago · 7 comments

How do I export FastSpeech2 to ONNX? I want to export it to ONNX format and then to TFLite for deployment.

youngstu · Sep 03 '21 04:09

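I solved it. My error was caused by ONNX export not supporting torch.bucketize(), so I rewrote the bucketize function following https://github.com/pytorch/pytorch/issues/7284. Add this method to the class in model/modules.py that calls torch.bucketize:

def bucketize(self, tensor, bucket_boundaries):
    # ONNX-exportable replacement for torch.bucketize(tensor, bucket_boundaries)
    # with the default right=False: each value's bucket index is the number of
    # boundaries it is strictly greater than, so the boundaries must be sorted.
    result = torch.zeros_like(tensor, dtype=torch.int32)
    for boundary in bucket_boundaries:
        result += (tensor > boundary).int()
    return result.long()

Replace every call to torch.bucketize with self.bucketize.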
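
The counting trick can be sanity-checked without PyTorch. The NumPy sketch below is an illustration, not code from the repo: it compares the loop against np.searchsorted(..., side="left"), which uses the same convention as torch.bucketize with its default right=False, and adds a loop-free variant. When traced, the Python loop is unrolled into one comparison per boundary; the PyTorch analogue of the loop-free variant would be (tensor.unsqueeze(-1) > bucket_boundaries).sum(dim=-1), though whether it exports more compactly is an assumption worth checking on your own graph.

```python
import numpy as np


def bucketize_loop(values, boundaries):
    # Same counting trick as the bucketize() method above: every boundary
    # that a value is strictly greater than moves it one bucket up.
    result = np.zeros(values.shape, dtype=np.int64)
    for boundary in boundaries:
        result += (values > boundary).astype(np.int64)
    return result


def bucketize_vectorized(values, boundaries):
    # Loop-free variant: add a trailing axis so every value is compared
    # with every boundary at once, then count the True entries.
    return (values[..., None] > boundaries).sum(axis=-1)


boundaries = np.array([0.0, 1.0, 2.0])  # must be sorted ascending
values = np.array([[-0.5, 0.0, 0.5], [1.0, 1.5, 3.0]])

print(bucketize_loop(values, boundaries).tolist())                # [[0, 0, 1], [1, 2, 3]]
print(np.searchsorted(boundaries, values, side="left").tolist())  # [[0, 0, 1], [1, 2, 3]]
```

Values equal to a boundary (0.0 and 1.0 here) land in the lower bucket, which is what right=False means.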

For the export itself, my code is:

import numpy as np
import torch

# Assumes model, batch and device are set up as in the repo's synthesis script,
# with batch[2:] = (speakers, texts, src_lens, max_src_len).
input_names = ['speakers', 'texts', 'src_lens', 'max_src_len']
output_names = ['output', 'postnet_output', 'p_predictions', 'e_predictions', 'log_d_predictions', 'd_rounded', 'src_masks', 'mel_masks', 'src_lens', 'mel_lens']
dynamic_axes = {
    "texts": {1: "texts_len"},
    "output": {1: "output_len"},
    "postnet_output": {1: "postnet_output_len"},
    "p_predictions": {1: "p_predictions_len"},
    "e_predictions": {1: "e_predictions_len"},
    "log_d_predictions": {1: "log_d_predictions_len"},
    "d_rounded": {1: "d_rounded_len"},
    "src_masks": {1: "src_masks_len"}
}

dummy_input_1 = batch[2]  # speakers
dummy_input_2 = batch[3]  # texts
dummy_input_3 = batch[4]  # src_lens
# max_src_len is not a tensor, so wrap it as a 0-d tensor on the same device
dummy_input_4 = torch.from_numpy(np.array(batch[5])).to(device)
torch.onnx.export(
    model,
    args=(dummy_input_1, dummy_input_2, dummy_input_3, dummy_input_4),
    f="FastSpeech.onnx",
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=11,
)
 

jerryuhoo · Oct 27 '21 02:10
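
One way to check whether an exported graph really accepts other text lengths is to run it on two inputs of different length and compare the output shapes. A minimal sketch of building the ONNX Runtime feed, assuming the four input names from the export call and int64 dtypes matching the LongTensor dummy inputs (make_feed is a hypothetical helper, not part of the repo):

```python
import numpy as np

INPUT_NAMES = ["speakers", "texts", "src_lens", "max_src_len"]  # as in the export call


def make_feed(phoneme_ids, speaker_id=0):
    # Hypothetical helper: builds the model inputs for one utterance of any length.
    # int64 is assumed to match the LongTensors used as dummy inputs at export.
    texts = np.asarray([phoneme_ids], dtype=np.int64)  # (1, T); T is the dynamic axis
    real_len = texts.shape[1]
    return {
        "speakers": np.asarray([speaker_id], dtype=np.int64),  # (1,)
        "texts": texts,
        "src_lens": np.asarray([real_len], dtype=np.int64),    # (1,)
        "max_src_len": np.asarray(real_len, dtype=np.int64),   # 0-d, like np.array(batch[5])
    }


short_feed = make_feed([12, 45, 7, 33, 2])
long_feed = make_feed(list(range(1, 41)))

# With onnxruntime installed, each feed can be run against the exported file:
# import onnxruntime as ort
# sess = ort.InferenceSession("FastSpeech.onnx", providers=["CPUExecutionProvider"])
# postnet = sess.run(["postnet_output"], long_feed)[0]
```

If the long feed fails, or both feeds return postnet_output with the same time length, some length was frozen into the graph during tracing.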

mark

Pydataman · Dec 29 '21 06:12

@jerryuhoo Excuse me, how did you get dynamic input working at inference time? I also set dynamic_axes, but I can't use inputs of a different length at inference: every output has the same shape as the output produced when I converted the model to ONNX, so the output is never dynamic. My details are in this link.

Tian14267 · Jan 18 '22 07:01

@jerryuhoo I have the same problem. Even though texts and text_lens were exported with dynamic axes, the graph somehow isn't traced as fully dynamic: it only passes onnxruntime when I use the same input shape as at export time.

python -m onnxsim fastspeech2.onnx fastspeech2_sim.onnx --dynamic-input-shape --input-shape tones:1,58,8 texts:1,58 text_lens:1
Simplifying...
Checking 0/3...
Checking 1/3...
Checking 2/3...
Ok!

So I think the workaround is to pad the input to the same size used at export time, which keeps the input shape fixed.

But then I don't know how to trim the postnet output to the real length of the input text.

lucasjinreal · Jan 24 '22 02:01
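
A minimal sketch of that padding workaround, including the trimming step. It assumes that id 0 is the padding symbol, that the export length is 58 (taken from the onnxsim command above), that the real length is still passed in src_lens, and that the model's d_rounded output holds per-token frame counts. Because the length regulator expands tokens in order and the padding sits at the end, the real frames should be the first sum(d_rounded[:real_len]) frames of postnet_output. pad_texts and trim_postnet_output are hypothetical helpers.

```python
import numpy as np

EXPORT_LEN = 58  # text length used at export time (texts:1,58 in the onnxsim command)
PAD_ID = 0       # assumption: id 0 is the padding symbol


def pad_texts(phoneme_ids, export_len=EXPORT_LEN, pad_id=PAD_ID):
    """Pad one utterance to the fixed export length; return (texts, src_lens)."""
    real_len = len(phoneme_ids)
    if real_len > export_len:
        raise ValueError(f"{real_len} tokens exceed the fixed length {export_len}")
    texts = np.full((1, export_len), pad_id, dtype=np.int64)
    texts[0, :real_len] = phoneme_ids
    # Pass the REAL length so the source mask still hides the padding.
    src_lens = np.array([real_len], dtype=np.int64)
    return texts, src_lens


def trim_postnet_output(postnet_output, d_rounded, real_len):
    """Keep only the mel frames generated by the real (unpadded) tokens.

    postnet_output: (1, T_mel, n_mels); d_rounded: (1, T_text) frames per token.
    Tokens are expanded in order, so the real tokens' frames come first.
    """
    n_frames = int(np.sum(d_rounded[0, :real_len]))
    return postnet_output[0, :n_frames]


texts, src_lens = pad_texts([5, 6, 7])
# Fake model outputs for illustration: padding tokens get 1 frame each.
d_rounded = np.ones((1, EXPORT_LEN))
d_rounded[0, :3] = [2, 3, 1]
postnet_output = np.zeros((1, int(d_rounded.sum()), 80))
mel = trim_postnet_output(postnet_output, d_rounded, int(src_lens[0]))
print(mel.shape)  # (6, 80)
```

If tracing also froze the mel length, postnet_output may be shorter than n_frames; the slice then silently returns fewer frames rather than raising.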

@jinfagang How did you convert torch.linspace in the variance predictor? I get the error "Exporting the operator linspace to ONNX opset version 11 is not supported". My torch version is 1.7.0.

OnceJune · Feb 15 '22 01:02
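
A common workaround (a sketch, not an approach confirmed in this thread) is to rewrite linspace in terms of arange, since linspace(start, end, steps)[k] = start + k * (end - start) / (steps - 1). In PyTorch that would be start + torch.arange(steps) * step; torch.arange should map to the ONNX Range operator that opset 11 provides, though this is worth confirming for torch 1.7.0. If the bins depend only on config values, computing them once outside forward() also keeps linspace out of the traced graph. The NumPy check below only verifies the formula:

```python
import numpy as np


def linspace_via_arange(start, end, steps):
    # Evenly spaced values from start to end inclusive, built from arange only.
    # Assumes steps >= 2 (steps == 1 would divide by zero).
    step = (end - start) / (steps - 1)
    return start + np.arange(steps) * step


bins = linspace_via_arange(-1.0, 1.0, 5)
print(bins.tolist())  # [-1.0, -0.5, 0.0, 0.5, 1.0]
```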

@jinfagang I think the .item() call turns max_len into a constant, as this warning says:

TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_len = torch.max(lengths).item()

zhanminmin · Mar 17 '22 12:03
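
The warning points at a mask helper that computes max_len with .item(). A minimal sketch of a version that keeps max_len as a tensor, assuming the helper's convention that True marks padded positions (the function name and signature here are illustrative). Whether the exporter then treats the length as dynamic may depend on the PyTorch version, so it is worth re-testing the exported model with inputs of a different length; other .item() calls elsewhere in the model may freeze shapes as well.

```python
import torch


def get_mask_from_lengths(lengths, max_len=None):
    # lengths: (B,) LongTensor of real sequence lengths.
    # Returns a (B, max_len) bool mask that is True at padded positions.
    if max_len is None:
        # Keep max_len as a 0-d tensor instead of calling .item(), so it is
        # not turned into a Python number (which the tracer records as a constant).
        max_len = torch.max(lengths)
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) >= lengths.unsqueeze(1)


mask = get_mask_from_lengths(torch.tensor([3, 1]))
print(mask.tolist())  # [[False, False, False], [False, True, True]]
```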