
I have a question about your WER/CER results in the paper.

SeongYeonPark opened this issue 3 years ago · 2 comments

In your paper, you report WER and CER of about 4.23% and 1.46%, and you mention that you used https://huggingface.co/facebook/hubert-large-ls960-ft as the ASR model.

But when I use the same ASR model on ground-truth VCTK utterances, I get a WER/CER of about 6.43% and 1.95%, so I assume our code for measuring WER/CER differs.
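
For reference, this is roughly what I did (directory layout and normalization are illustrative, not my exact script): transcribe each ground-truth VCTK wav with the HuBERT model and score it against the official VCTK transcript with jiwer.

import re
import torch
import librosa
from glob import glob
from jiwer import wer, cer
from transformers import Wav2Vec2Processor, HubertForCTC

model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").eval()
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")

refs, hyps = [], []
for path in sorted(glob("vctk_gt/*.wav")):  # illustrative directory
    # the official VCTK transcript is assumed to sit next to the wav;
    # crude normalization, since HuBERT outputs uppercase text without punctuation
    with open(path.replace(".wav", ".txt")) as f:
        refs.append(" ".join(re.sub(r"[^A-Z' ]", " ", f.read().upper()).split()))
    wav = librosa.load(path, sr=16000)[0]
    input_values = processor([wav], sampling_rate=16000, return_tensors="pt").input_values
    with torch.no_grad():
        predicted_ids = torch.argmax(model(input_values).logits, dim=-1)
    hyps.append(processor.batch_decode(predicted_ids)[0])

print("WER:", wer(refs, hyps))
print("CER:", cer(refs, hyps))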

Could you share the code you use to evaluate WER/CER, or at least a fragment of it? Thank you.

SeongYeonPark avatar Dec 06 '22 01:12 SeongYeonPark

"Word error rate (WER) and character error rate (CER) between source and converted speech", I used the transcriptions of source speech obtained by the ASR model as the ground truth.

get_gt.py

from transformers import Wav2Vec2Processor, HubertForCTC
import os
import argparse
import torch
import librosa
from tqdm import tqdm
from glob import glob

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--txtpath", type=str, default="gt.txt", help="path to the output transcription file")
    parser.add_argument("--wavdir", type=str, default="SOURCE")
    args = parser.parse_args()

    # load the ASR model and its processor
    model_text = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").cuda()
    processor_text = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")

    # transcribe every source wav and write one "path|transcription" line per file
    wavs = glob(f'{args.wavdir}/*.wav')
    wavs.sort()
    with open(args.txtpath, "w") as f:
        for path in tqdm(wavs):
            wav = [librosa.load(path, sr=16000)[0]]
            input_values = processor_text(wav, sampling_rate=16000, return_tensors="pt").input_values.cuda()
            with torch.no_grad():
                logits = model_text(input_values).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            text = processor_text.batch_decode(predicted_ids)[0]
            f.write(f"{path}|{text}\n")
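
Run this on the source wavs first, e.g. python get_gt.py --wavdir SOURCE --txtpath gt.txt; it writes one path|transcription line per utterance, and wer.py below reads that gt.txt as the reference.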

wer.py

from transformers import Wav2Vec2Processor, HubertForCTC
import os
import argparse
import torch
import librosa
from tqdm import tqdm
from glob import glob
from jiwer import wer, cer

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--wavdir", type=str, default="PROPOSED")
    parser.add_argument("--outdir", type=str, default="result", help="path to output dir")
    parser.add_argument("--use_cuda", default=False, action="store_true")
    args = parser.parse_args()
    
    os.makedirs(args.outdir, exist_ok=True)

    # load model and processor
    model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
    if args.use_cuda:
        model = model.cuda()
    processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
    
    # load the ground-truth transcriptions written by get_gt.py
    gt_dict = {}
    with open("gt.txt", "r") as f:
        for line in f:
            path, text = line.strip().split("|")
            title = os.path.basename(path)[:-4]  # filename without ".wav"
            gt_dict[title] = text
    
    # get transcriptions
    wavs = glob(f'{args.wavdir}/*.wav')
    wavs.sort()
    trans_dict = {}
    
    with open(f"{args.outdir}/text.txt", "w") as f:
        for path in tqdm(wavs):
            wav = [librosa.load(path, sr=16000)[0]]
            input_values = processor(wav, sampling_rate=16000, return_tensors="pt").input_values
            if args.use_cuda:
                input_values = input_values.cuda()
            with torch.no_grad():
                logits = model(input_values).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            text = processor.batch_decode(predicted_ids)[0]
            f.write(f"{path}|{text}\n")
            title = os.path.basename(path)[:-4]
            trans_dict[title] = text
    
    # pair each converted transcription with its source transcription;
    # "<source_title>-<...>.wav" is mapped back to "<source_title>"
    gts, trans = [], []
    for key, text in trans_dict.items():
        trans.append(text)
        gts.append(gt_dict[key.split("-")[0]])

    # use new names so the imported jiwer functions are not shadowed
    wer_score = wer(gts, trans)
    cer_score = cer(gts, trans)
    with open(f"{args.outdir}/wer.txt", "w") as f:
        f.write(f"wer: {wer_score}\n")
        f.write(f"cer: {cer_score}\n")
    print("WER:", wer_score)
    print("CER:", cer_score)
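
Then score the converted wavs against it, e.g. python wer.py --wavdir PROPOSED --outdir result --use_cuda. Note that the matching step (key.split("-")[0]) assumes each converted file is named <source_title>-<something>.wav so it can be traced back to its source utterance in gt.txt.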

OlaWod avatar Dec 06 '22 02:12 OlaWod

Thank you for answering!

SeongYeonPark avatar Dec 06 '22 04:12 SeongYeonPark