Retrieval-based-Voice-Conversion-WebUI

Suggestion: Drop `.pth` models and move to safetensors

Open 413x1nkp opened this issue 1 year ago • 10 comments

This idea has been around for quite some time. The main reasoning is that `.pth` files are fundamentally unsafe: they are pickle-based, so a tampered file can execute arbitrary code when it is loaded. The original repo's mitigation was to check the SHA256 of the main weights, which at least ensures that the official weights haven't been tampered with. However, this problem has a much simpler fix: store the weights as safetensors instead of raw `.pth` pickles. That way no validity check is required for safety, so startup time can be cut significantly. It also improves security, as it will no longer be possible to inject malicious code into the weights.
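For reference, the safetensors layout is deliberately simple: an 8-byte little-endian unsigned integer giving the header length, a UTF-8 JSON header mapping each tensor name to its `dtype`, `shape`, and `data_offsets`, and then the raw tensor bytes. The stdlib sketch below writes and reads that layout by hand only to show that loading is plain parsing with no code execution. The function names are hypothetical; real code would use the `safetensors` package (e.g. `safetensors.torch.save_file` / `load_file`).

```python
import json
import os
import struct
import tempfile


def write_minimal_safetensors(path, name, values):
    """Write one 1-D float32 tensor in the safetensors layout (illustration only)."""
    data = struct.pack(f"<{len(values)}f", *values)  # little-endian float32
    header = {
        name: {"dtype": "F32", "shape": [len(values)], "data_offsets": [0, len(data)]}
    }
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # u64 header length
        f.write(header_bytes)
        f.write(data)


def read_tensor(path, name):
    """Parse the JSON header and decode one float32 tensor; nothing is executed."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
        buffer = f.read()  # data_offsets are relative to the start of this buffer
    begin, end = header[name]["data_offsets"]
    count = (end - begin) // 4  # 4 bytes per float32
    return header, list(struct.unpack(f"<{count}f", buffer[begin:end]))


path = os.path.join(tempfile.mkdtemp(), "demo.safetensors")
write_minimal_safetensors(path, "weight", [1.0, 2.0, 3.0])
header, values = read_tensor(path, "weight")
```

Unlike a pickle, nothing in this format can name a function to call, which is why loading it needs no hash check for safety.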

Sadly, I don't remember who the author of the original idea was; however, I believe it was notfelt from the AIHub Discord server.

413x1nkp • Jun 12 '24 11:06

There are many external pths, like uvr5. We can drop support for those first, then consider the safetensors implementation.

P.S. The hash is not only for safety, but also because the models are large: if a model gets corrupted during download, a check at start time is necessary. That said, we could wrap a lazy check that verifies the hash only when the model is about to be loaded.

fumiama • Jun 12 '24 12:06

> There are many external pths, like uvr5. We can drop support for those first, then consider the safetensors implementation.

That sounds like a good plan!

> P.S. The hash is not only for safety, but also because the models are large: if a model gets corrupted during download, a check at start time is necessary. That said, we could wrap a lazy check that verifies the hash only when the model is about to be loaded.

Huh, I see. A lazy check does sound like the best solution in that case.

413x1nkp • Jun 12 '24 12:06

> A lazy check does sound like the best solution in that case.

Checking a single pth should take less than 1 s, which may be acceptable.

fumiama • Jun 12 '24 12:06

> Checking a single pth should take less than 1 s, which may be acceptable.

I mean, the check is simply "calculate the hash of the pth and compare it to the stored hash"; done in C, it could take under 0.1 s. I'm not sure it's a good solution right now, though: RVC is not yet structured as a module, so the hash checker would have to be built externally and used as a Python module (at least until RVC itself is made into a module).

I can start working on the hash checker in C and bind it to Python functions so it's easy to use as a module.

413x1nkp • Jun 12 '24 12:06

Well, I don't know whether you know it or not, but Python's standard hashlib is not pure Python; it's written in C, if my memory is correct.

fumiama • Jun 12 '24 12:06

> Well, I don't know whether you know it or not, but Python's standard hashlib is not pure Python; it's written in C, if my memory is correct.

Huh, I didn't know that. However, I meant a fully RVC-specific implementation that checks all of the files against their hashes at once, instead of calling hashlib for each file independently.

Since the number of files is known beforehand and all of their hashes are accessible, this might improve performance, as we'd make a single call into C to hash-check all of the files.

So, if we don't hardcode the hash values, it would look like this:
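For comparison, a similar single entry point can be sketched in pure Python: hashlib already does the hashing in C and, per its documentation, releases the GIL for updates larger than 2047 bytes, so several files can be hashed in parallel threads. This is an alternative technique, not the proposed C module; `check_hashes` and the `manifest` dict (file name to expected hex digest) are hypothetical.

```python
import hashlib
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor


def sha256_file(path, chunk_size=1 << 20):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)  # runs in C; the GIL is released for large chunks
    return h.hexdigest()


def check_hashes(base_dir, manifest, max_workers=4):
    """Return the names whose SHA-256 does not match; missing files also count."""

    def is_bad(name):
        path = os.path.join(base_dir, name)
        if not os.path.isfile(path):
            return True
        return sha256_file(path) != manifest[name].lower()

    names = sorted(manifest)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        flags = list(pool.map(is_bad, names))
    return [name for name, bad in zip(names, flags) if bad]


# Usage: two correct stand-in files plus one entry that is missing on disk.
base = tempfile.mkdtemp()
manifest = {}
for name, payload in [("G40k.pth", b"generator"), ("D40k.pth", b"discriminator")]:
    with open(os.path.join(base, name), "wb") as f:
        f.write(payload)
    manifest[name] = hashlib.sha256(payload).hexdigest()
manifest["f0G40k.pth"] = "0" * 64  # not on disk, so it is reported
mismatches = check_hashes(base, manifest)
```

With only a dozen files, the Python-level loop runs a dozen times; the per-byte work stays inside hashlib either way.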

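```python
import logging
import os
from pathlib import Path

import RVC_Hash  # proposed C extension module (does not exist yet)

logger = logging.getLogger(__name__)
# check_model(dir, name, expected_sha256, update) is the existing per-file
# helper from the current RVC code and is assumed to be in scope.


def check_all_assets(update=False) -> bool:
    BASE_DIR = Path(__file__).resolve().parent.parent.parent

    logger.info("checking hubert & rmvpe...")

    if not check_model(
        BASE_DIR / "assets" / "hubert",
        "hubert_base.pt",
        os.environ["sha256_hubert_base_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.pt",
        os.environ["sha256_rmvpe_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.onnx",
        os.environ["sha256_rmvpe_onnx"],
        update,
    ):
        return False

    rvc_models_dir = BASE_DIR / "assets" / "pretrained"
    logger.info("checking pretrained models...")
    model_names = [
        "D32k.pth",
        "D40k.pth",
        "D48k.pth",
        "G32k.pth",
        "G40k.pth",
        "G48k.pth",
        "f0D32k.pth",
        "f0D40k.pth",
        "f0D48k.pth",
        "f0G32k.pth",
        "f0G40k.pth",
        "f0G48k.pth",
    ]
    # Hypothetical API: one call into C checks every file; the expected hashes
    # would be read from the sha256_v1_* environment variables.
    return RVC_Hash.check_hashes(rvc_models_dir, model_names)
```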

Or, if we do hardcode the values:

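```python
import logging
import os
from pathlib import Path

import RVC_Hash  # proposed C extension module (does not exist yet)

logger = logging.getLogger(__name__)
# check_model(dir, name, expected_sha256, update) is the existing per-file
# helper from the current RVC code and is assumed to be in scope.


def check_all_assets(update=False) -> bool:
    BASE_DIR = Path(__file__).resolve().parent.parent.parent

    logger.info("checking hubert & rmvpe...")

    if not check_model(
        BASE_DIR / "assets" / "hubert",
        "hubert_base.pt",
        os.environ["sha256_hubert_base_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.pt",
        os.environ["sha256_rmvpe_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.onnx",
        os.environ["sha256_rmvpe_onnx"],
        update,
    ):
        return False

    rvc_models_dir = BASE_DIR / "assets" / "pretrained"
    logger.info("checking pretrained models...")
    # Hypothetical API: the file names and their hashes are compiled into
    # the extension, so only the directory needs to be passed.
    return RVC_Hash.check_hashes(rvc_models_dir)
```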

Instead of the current code:

```python
import logging
import os
from pathlib import Path

logger = logging.getLogger(__name__)
# check_model(dir, name, expected_sha256, update) is the existing per-file
# helper from the current RVC code and is assumed to be in scope.


def check_all_assets(update=False) -> bool:
    BASE_DIR = Path(__file__).resolve().parent.parent.parent

    logger.info("checking hubert & rmvpe...")

    if not check_model(
        BASE_DIR / "assets" / "hubert",
        "hubert_base.pt",
        os.environ["sha256_hubert_base_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.pt",
        os.environ["sha256_rmvpe_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.onnx",
        os.environ["sha256_rmvpe_onnx"],
        update,
    ):
        return False

    rvc_models_dir = BASE_DIR / "assets" / "pretrained"
    logger.info("checking pretrained models...")
    model_names = [
        "D32k.pth",
        "D40k.pth",
        "D48k.pth",
        "G32k.pth",
        "G40k.pth",
        "G48k.pth",
        "f0D32k.pth",
        "f0D40k.pth",
        "f0D48k.pth",
        "f0G32k.pth",
        "f0G40k.pth",
        "f0G48k.pth",
    ]
    for model in model_names:
        menv = model.replace(".", "_")  # e.g. "G40k.pth" -> "G40k_pth"
        if not check_model(
            rvc_models_dir, model, os.environ[f"sha256_v1_{menv}"], update
        ):
            return False
    return True
```

The reasoning is that Python's for loops carry a small per-iteration overhead that C loops don't have.

413x1nkp • Jun 12 '24 12:06

Well, if you want to write a specialized program for this, I won't refuse it, but it should be platform-independent: runnable under Windows, Linux, macOS, etc., and on architectures such as amd64, arm64, etc.

fumiama • Jun 12 '24 14:06

> Well, if you want to write a specialized program for this, I won't refuse it, but it should be platform-independent: runnable under Windows, Linux, macOS, etc., and on architectures such as amd64, arm64, etc.

Noted! I'll also try to make it work on both little-endian and big-endian systems!

413x1nkp • Jun 12 '24 14:06

Some interesting info here: https://huggingface.co/docs/hub/security-pickle
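One idea from that page is to inspect which globals a pickle would import before loading it. Below is a rough stdlib sketch using `pickletools.genops`; `imported_globals` is a hypothetical helper and only a heuristic (it does not resolve strings fetched from the memo with opcodes such as `BINGET`). For a zip-based `.pth`, the pickle stream sits inside the archive (typically a `data.pkl` member) rather than being the whole file.

```python
import pickle
import pickletools

STRING_OPS = {"SHORT_BINUNICODE", "BINUNICODE", "BINUNICODE8", "UNICODE"}


def imported_globals(data):
    """Return 'module.name' for each global the pickle would import (heuristic)."""
    found = []
    strings = []  # string literals pushed so far, used to resolve STACK_GLOBAL
    for opcode, arg, _pos in pickletools.genops(data):
        if opcode.name in STRING_OPS:
            strings.append(arg)
        elif opcode.name == "GLOBAL":  # protocols 0-3: arg is "module name"
            module, name = arg.split(" ", 1)
            found.append(f"{module}.{name}")
        elif opcode.name == "STACK_GLOBAL" and len(strings) >= 2:
            found.append(f"{strings[-2]}.{strings[-1]}")  # protocol 4+
    return found


class Payload:
    """Hypothetical malicious object: pickle would call print() when loading it.

    A real attack would return something like os.system instead.
    """

    def __reduce__(self):
        return (print, ("pwned",))


clean = pickle.dumps({"weight": [0.1, 0.2]}, protocol=4)
tainted = pickle.dumps({"weight": [0.1, 0.2], "x": Payload()}, protocol=4)
clean_globals = imported_globals(clean)
tainted_globals = imported_globals(tainted)
```

Legitimate PyTorch checkpoints also reference globals such as `collections.OrderedDict` and `torch._utils._rebuild_tensor_v2`, so in practice the output would be compared against an allowlist rather than required to be empty.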

blaisewf • Jun 12 '24 20:06

Alternatively, you can explicitly call `torch.load` with the argument `weights_only=True`. This will be the default in future PyTorch releases.
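For the RVC checkpoints this would be a call along the lines of `torch.load(path, map_location="cpu", weights_only=True)`. PyTorch implements that flag with its own restricted unpickler that only accepts an allowlist of types. The underlying idea is the "Restricting Globals" technique from the Python `pickle` documentation, sketched below with the stdlib; `AllowlistUnpickler` and its one-entry allowlist are illustrative assumptions, not PyTorch's code.

```python
import io
import pickle
from collections import OrderedDict


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that refuses to import any global outside a small allowlist."""

    ALLOWED = {("collections", "OrderedDict")}

    def find_class(self, module, name):
        # Every global referenced by the pickle is resolved through here.
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked global: {module}.{name}")


def restricted_loads(data):
    return AllowlistUnpickler(io.BytesIO(data)).load()


class Payload:
    """Hypothetical malicious object that would call print() when unpickled."""

    def __reduce__(self):
        return (print, ("pwned",))


# A state-dict-like object loads fine; the payload is rejected before it runs.
state = restricted_loads(pickle.dumps(OrderedDict(weight=[0.5, 1.5])))

try:
    restricted_loads(pickle.dumps(Payload()))
    blocked = False
except pickle.UnpicklingError:
    blocked = True
```

An allowlist fails closed: an unexpected global causes a load error instead of running, which is the property `weights_only=True` relies on.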

TheTrustedComputer • Aug 27 '24 20:08