Retrieval-based-Voice-Conversion-WebUI

Improve the model while maintaining compatibility

nadare881 opened this issue 2 years ago · 3 comments

Here are some suggestions for making training and inference more efficient without changing the structure of the model.

  • Add support for mixed precision in training and inference, with a bfloat16 option

Mixed precision training is done with "from torch.cuda.amp import autocast, GradScaler". This approach keeps the model parameters in float32 and runs only the intermediate computations in float16, so it combines most of float32's accuracy with float16's efficiency. Doing the same during inference would keep reasonable accuracy even when the float16 option is enabled. Mixed precision is also often run with bfloat16 instead of float16, so this change would let you choose between float32, float16, and bfloat16 for training. (One trade-off: until now, converting the model to half precision reduced the memory the model itself occupies; with autocast the weights stay in float32, so that saving would disappear.)
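To make the float16 versus bfloat16 choice concrete, here is a small pure-Python sketch (it does not use PyTorch; to_float16 and to_bfloat16 are hypothetical helpers written only for this illustration) that rounds a number to each 16-bit format. It shows that bfloat16 keeps float32's exponent range but has fewer mantissa bits than float16.

```python
import struct


def to_float16(x: float) -> float:
    """Round x to IEEE 754 half precision (5 exponent bits, 10 mantissa bits).

    struct's "e" format raises OverflowError if x exceeds the float16 range
    (max 65504) and rounds values below the subnormal range to 0.0.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x: float) -> float:
    """Round x to bfloat16 (8 exponent bits, 7 mantissa bits).

    bfloat16 is the upper half of a float32, so the float32 bit pattern is
    rounded to its top 16 bits (round half to even). NaN is not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    upper, lower = bits >> 16, bits & 0xFFFF
    if lower > 0x8000 or (lower == 0x8000 and upper & 1):
        upper += 1
    return struct.unpack("<f", struct.pack("<I", upper << 16))[0]


# Precision: float16 can represent 1 + 2**-9, bfloat16 rounds it to 1.0.
print(to_float16(1 + 2**-9), to_bfloat16(1 + 2**-9))  # 1.001953125 1.0
# Range: 70000 overflows float16 but fits (coarsely) in bfloat16.
print(to_bfloat16(70000.0))                           # 70144.0
# Tiny values such as small gradients: float16 flushes 1e-8 to zero.
print(to_float16(1e-8), to_bfloat16(1e-8) != 0.0)     # 0.0 True
```

Keeping small float16 gradients like the 1e-8 above from underflowing is what the loss scaling in GradScaler is for; with bfloat16 that range problem largely goes away, at the cost of precision (7 mantissa bits instead of 10).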

  • Prepare for parameter changes

Currently "segment_size" is set to only 11520 samples. Even though training clips are up to 4 seconds long, the GAN therefore compares segments shorter than 0.25 seconds. However, with the current module, increasing this parameter causes an error, so I fixed some variables to prevent it. With this fix, training can proceed more efficiently.
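As a rough check on those numbers, the sketch below assumes the 48 kHz configuration (sampling rate 48000, upsample_rates [10,6,2,2,2], hence a hop length of 480 samples per frame) and VITS-style random slicing, where the decoder is trained on a random window of segment_size // hop_length frames. The constants and the helpers segment_frames and random_segment are illustrative assumptions, not RVC's actual code; the divisibility check is added only for clarity and does not reproduce the error mentioned above.

```python
import math
import random

# Assumed 48 kHz configuration values (illustrative constants).
SAMPLING_RATE = 48_000                  # samples per second
UPSAMPLE_RATES = [10, 6, 2, 2, 2]
HOP_LENGTH = math.prod(UPSAMPLE_RATES)  # 480 samples per frame
SEGMENT_SIZE = 11_520                   # samples per training segment


def segment_frames(segment_size: int, hop_length: int) -> int:
    """Number of frames that correspond to one segment of segment_size samples."""
    if segment_size % hop_length:
        raise ValueError("segment_size must be a multiple of hop_length")
    return segment_size // hop_length


def random_segment(num_frames: int, seg_frames: int, hop_length: int, rng: random.Random):
    """Pick a random window of seg_frames frames and the matching sample range."""
    start = rng.randrange(0, num_frames - seg_frames + 1)
    frame_range = (start, start + seg_frames)
    sample_range = (start * hop_length, (start + seg_frames) * hop_length)
    return frame_range, sample_range


frames = segment_frames(SEGMENT_SIZE, HOP_LENGTH)  # 24 frames
seconds = SEGMENT_SIZE / SAMPLING_RATE             # 0.24 s
clip_frames = 4 * SAMPLING_RATE // HOP_LENGTH      # 400 frames in a 4 s clip
print(frames, seconds, clip_frames)                # 24 0.24 400
print(random_segment(clip_frames, frames, HOP_LENGTH, random.Random(0)))
```

Under these assumptions, only 24 of the 400 frames in a 4-second clip (6%) are compared per step, which is why a larger segment_size could make training more efficient.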

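Also, each value in "upsample_rates": [10,6,2,2,2] currently has to be a multiple of 2, which is why the first layer does a large upsampling by 10 (the rates multiply to 480, so the odd factors 3 and 5 have to be folded into 6 and 10). This restriction comes from ConvTranspose1d, but odd rates become usable by setting output_padding. This does not affect the structure of the model, so I'll fix it: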

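    from torch.nn import ConvTranspose1d
    from torch.nn.utils import weight_norm

    # Inside the generator's __init__: one transposed convolution per upsampling stage
    for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
        # output_padding absorbs the odd part of (k - u), so each layer
        # outputs exactly u times as many samples as it receives
        op = (k - u) % 2
        p = (k - u + op) // 2
        self.ups.append(
            weight_norm(
                ConvTranspose1d(
                    upsample_initial_channel // (2 ** i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    k,
                    u,
                    padding=p,
                    output_padding=op,
                )
            )
        )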
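The length claim can be checked without PyTorch by plugging both padding rules into the output-length formula documented for torch.nn.ConvTranspose1d. This is a pure-Python sketch; the helper names are hypothetical, and the kernel sizes [16,16,4,4,4] in the even-rate example are an assumed configuration.

```python
def conv_transpose1d_out_len(l_in, kernel, stride, padding=0, output_padding=0, dilation=1):
    """Output length of torch.nn.ConvTranspose1d, per its documented formula."""
    if not 0 <= output_padding < max(stride, dilation):
        # PyTorch rejects output_padding >= max(stride, dilation)
        raise ValueError("output_padding must be smaller than stride or dilation")
    return (l_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def current_padding(k, u):
    """Padding before the change: no output_padding."""
    return (k - u) // 2, 0


def proposed_padding(k, u):
    """Padding from the proposal: output_padding absorbs an odd (k - u)."""
    op = (k - u) % 2
    return (k - u + op) // 2, op


def upsampled_len(l_in, rates, kernels, padding_rule):
    """Push a length through a stack of transposed convolutions."""
    length = l_in
    for u, k in zip(rates, kernels):
        p, op = padding_rule(k, u)
        length = conv_transpose1d_out_len(length, k, u, padding=p, output_padding=op)
    return length


# Even rates: both rules turn 100 frames into 100 * 480 samples.
print(upsampled_len(100, [10, 6, 2, 2, 2], [16, 16, 4, 4, 4], current_padding))   # 48000
print(upsampled_len(100, [10, 6, 2, 2, 2], [16, 16, 4, 4, 4], proposed_padding))  # 48000
# An odd rate (5) with an even kernel (16): the old rule is off by one sample.
print(upsampled_len(100, [5], [16], current_padding))   # 501
print(upsampled_len(100, [5], [16], proposed_padding))  # 500
```

One caveat of the proposed rule: PyTorch requires output_padding to be smaller than the stride (with dilation 1), so a layer with rate 1 and an odd k - u would be rejected; in this sketch that case raises ValueError.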

If there is no problem with the above, I will create a PR.

nadare881 · Jun 07 '23 01:06

1. How can we detect whether the GPU supports bf16?
2. (1) 11520: Sounds good, you can fix it. (2) Will odd numbers cause alignment problems?

RVC-Boss · Jun 07 '23 02:06

1. How can we detect whether the GPU supports bf16?

Mixed precision works on most hardware in TensorFlow, and PyTorch is probably similar. I don't think RVC itself needs to perform this check, although mixed precision can be slower in some environments.

From the TensorFlow guide: "Even on CPUs and older GPUs, where no speedup is expected, mixed precision APIs can still be used for unit testing, debugging, or just to try out the API. On CPUs, mixed precision will run significantly slower, however."

https://www.tensorflow.org/guide/mixed_precision#supported_hardware
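If RVC did want such a check in PyTorch, one possible sketch is below. It assumes a PyTorch version that provides torch.cuda.is_bf16_supported(); the helper names are hypothetical, and treating compute capability 8.x (Ampere) and newer as having native bf16 is a heuristic, not an official rule.

```python
def has_native_bf16(capability) -> bool:
    """Heuristic: native bf16 math arrived with Ampere (compute capability 8.x).

    capability is a (major, minor) tuple, e.g. from torch.cuda.get_device_capability().
    """
    major, _minor = capability
    return major >= 8


def detect_bf16() -> bool:
    """Return True if PyTorch reports bf16 support on the current CUDA device."""
    try:
        import torch
    except ImportError:
        return False  # no PyTorch available: caller should fall back to float16/float32
    if not torch.cuda.is_available():
        return False
    # Depending on the PyTorch version, this may also count emulated bf16,
    # so has_native_bf16() can serve as a stricter alternative.
    return bool(torch.cuda.is_bf16_supported())


if __name__ == "__main__":
    print(detect_bf16())
```

Returning False whenever PyTorch or CUDA is missing keeps the check safe to call on CPU-only machines, where the float32 path would be used anyway.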

2. (2) Will odd numbers cause alignment problems?

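With the current implementation, an odd upsampling rate (more precisely, any layer where the kernel size minus the rate is odd) makes that layer output one sample more than input length × rate, so the input and output lengths no longer match. With the modification above, every layer outputs exactly input length × rate, so the lengths match.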

nadare881 · Jun 07 '23 02:06

I was mistaken: training already uses float16 mixed precision by default. That seems to work fine, so I won't add any additional options.

nadare881 · Jun 07 '23 11:06

This issue was closed because it has been inactive for 15 days since being marked as stale.

github-actions[bot] · Apr 28 '24 04:04