
Discrepancy in custom transcribe pipeline vs. `model.transcribe()` for QuartzNet model

Open diarray-hub opened this issue 10 months ago • 10 comments

Describe the bug

I’m attempting to replicate the model.transcribe() method with a custom inference pipeline (for eventual mobile deployment in a Flutter app). However, I’m observing a large discrepancy between the outputs of my custom pipeline and the official model.transcribe() method on the exact same audio file. Specifically:

  1. The transcription text differs significantly.
  2. The Hypothesis.y_sequence shapes and values are drastically different. In the official transcribe pipeline, y_sequence is continuous of shape [T, D], while in my pipeline it ends up as discrete indices of shape [T].

I’ve tried calling the same model.decoding.ctc_decoder_predictions_tensor() function and ensuring all arguments match. I also verified that no additional data augmentation or special channel selection is being applied. Yet the results remain inconsistent.

We suspect there’s some hidden post-processing step (beyond decode_hypothesis()) or a difference in how transcribe() manages decoding configuration that we’re not replicating, but we can’t pinpoint where it’s happening.


Steps/Code to reproduce bug

Below is a minimal snippet of how I’m trying to replicate transcribe():

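from typing import Tuple

import librosa
import torch

def load_audio(filepath: str, sample_rate: int = 16000) -> Tuple[torch.Tensor, torch.Tensor]:
    # librosa returns a mono float32 NumPy array resampled to `sample_rate`
    audio_np, sr = librosa.load(filepath, sr=sample_rate)
    # Add a batch dimension: shape [1, num_samples]
    audio_tensor = torch.tensor(audio_np, dtype=torch.float32).unsqueeze(0)
    length_tensor = torch.tensor([audio_tensor.size(1)], dtype=torch.long)
    return audio_tensor, length_tensor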

def transcribe_inference(model, filepath: str, return_hypotheses: bool = False):
    # 1) Load the audio
    audio_tensor, length_tensor = load_audio(filepath)

    # 2) Forward pass; EncDecCTCModel.forward returns (log_probs, encoded_len, greedy_predictions)
    log_probs, encoded_len, predictions = model.forward(
        input_signal=audio_tensor, input_signal_length=length_tensor
    )

    # 3) Use the same decoding function as NeMo
    hypotheses, _ = model.decoding.ctc_decoder_predictions_tensor(
        decoder_outputs=log_probs,
        decoder_lengths=encoded_len,
        return_hypotheses=return_hypotheses,
    )
    return hypotheses, predictions

# Example usage:
transcriptions, predictions = transcribe_inference(model, "some_audio.wav", return_hypotheses=True)
print("Custom pipeline transcription:", transcriptions[0].text)

And here’s how I call the official method:

result = model.transcribe(["some_audio.wav"], return_hypotheses=True)
print("Official model.transcribe() result:", result[0].text)

Discrepancy:

  • The official model.transcribe() returns a Hypothesis object with y_sequence shaped [429, 46] (continuous), plus a very accurate text.
  • My pipeline yields y_sequence shaped [429] (discrete indices), and a less accurate text.
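The two shapes are related. Assuming the [T, D] tensor holds per-frame log-probabilities over D output symbols (presumably including the CTC blank), taking the argmax over the last dimension gives a [T] sequence of frame-level token IDs. A minimal sketch, with a random tensor standing in for real model output and the helper name `frame_ids_from_log_probs` being hypothetical:

```python
import torch

def frame_ids_from_log_probs(log_probs: torch.Tensor) -> torch.Tensor:
    """Map a [T, D] log-probability tensor to [T] frame-level token IDs (greedy argmax)."""
    if log_probs.dim() != 2:
        raise ValueError(f"expected a [T, D] tensor, got shape {tuple(log_probs.shape)}")
    return log_probs.argmax(dim=-1)

# Stand-in for model output: 429 frames, 46 symbols (shapes taken from the report above)
fake_log_probs = torch.randn(429, 46).log_softmax(dim=-1)
ids = frame_ids_from_log_probs(fake_log_probs)
print(ids.shape)  # torch.Size([429])
```

So a [T] ID sequence and a [T, D] log-prob tensor can describe the same decoding; the shape alone does not show that the two pipelines computed different logits.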

I tried enabling timestamps, verifying the change_decoding_strategy() settings, checking that dither and augmentation are disabled, etc. No luck so far. It seems the pipeline is missing an internal step that transcribe() performs after forward() but before returning the final Hypothesis.


Expected behavior

I expect that by calling the same decoding function (model.decoding.ctc_decoder_predictions_tensor) on the same input, I would get identical or near‑identical text/hypotheses as model.transcribe(). Instead, I’m seeing major differences in both text and the shape/values of Hypothesis.y_sequence.


Environment overview

  • Environment location: Bare-metal
  • Method of NeMo install: pip install nemo_toolkit['asr'] (version 2.1.0)
  • PyTorch version: 2.5.1
  • Python version: 3.10
  • OS: Ubuntu 22.04
  • GPU model: None (using CPU)

Additional context

  • My ultimate goal is to deploy the fine‑tuned QuartzNet model on mobile (Flutter). TorchScript attempts failed with dithering errors, so I pivoted to ONNX. I do understand that we'll have a ton of code to write in Dart to faithfully replicate the model.transcribe method (preprocessing, post-processing). But before doing that, we wanted to replicate the transcribe method in Python at a level closer to NeMo, to know exactly which pre/post-processing steps to port to Dart.
  • The main confusion: Why does transcribe() produce a [T, D] y_sequence with continuous values, while my pipeline produces a [T] discrete sequence even though I call the same decoding function?
  • Possibly some hidden step in _transcribe_output_processing() or decode_hypothesis() is not triggered in my pipeline. But I’ve tried manually calling them without success.

Any guidance on which part of the pipeline is missing or how to replicate model.transcribe() exactly would be very helpful. If there is also a simpler way to integrate NeMo in Flutter, please let me know.

Thank you!

diarray-hub, Mar 27 '25 19:03

@anyone, I would greatly appreciate any help on this

diarray-hub, Mar 31 '25 12:03

@anyone, I would greatly appreciate any help on this

I can't believe @anyone is actually someone's username 😆

diarray-hub, Apr 01 '25 12:04

Hi, this went unnoticed. Could you confirm whether you still see the issue?

nithinraok, Apr 28 '25 15:04

Yes, the issue is still present and I really need to fix it. Thank you so much for answering @nithinraok, I thought this would go unnoticed for years. I believe this transcription pipeline is missing a crucial step that model.transcribe() does carry out, but I can't figure out which step it is. I went through the ctc_models.py file (in which EncDecCTCModel, the class for QuartzNet, is defined) and through the source code of many parent classes and imported modules such as ctc_decoding.py and the ASRTranscriptionMixin class. I was unable to find the missing step needed to fully reproduce the workflow of the .transcribe method.

You could try the experiment with any pretrained QuartzNet model, or I could share my notebook with my fine-tuned version if you prefer. I need to understand this mismatch and any help would be greatly appreciated.

PS: As you might have read above, my goal is to embed my QuartzNet model in my Flutter app. If you know of any simpler way to do that without having to rewrite all the preprocessing and post-processing code, please do share your approaches. But I'm also working on something that requires me to use the forward method to run inference with three different types of ASR models, so I need to fix the above issue anyway. (Basically, I just want to reproduce the .transcribe method for now.)

diarray-hub, Apr 29 '25 14:04

Will have a look at it. Any reason why you are still using QuartzNet-based models when there are the latest Parakeet-based models: https://huggingface.co/collections/nvidia/parakeet-659711f49d1469e51546e021 ?

nithinraok, Apr 29 '25 15:04

Yes, two reasons actually: size and simplicity. I have also fine-tuned some models from the Parakeet family; I love them and I'll be fine-tuning more very soon. But for this specific application, since we want to deploy the model on smartphones that are not expected to be very smart, I wanted something simple and very small, so a convolutional architecture with character-based decoding sounded good. I was probably also influenced by the Intro to ASR with NeMo notebook, which has apparently been updated since I last read it.

But thank you very much @nithinraok and let me know if you would like to test with my model and notebook.

diarray-hub, Apr 29 '25 17:04

Hi @nithinraok,

Thanks again for your help so far! I ran a few more experiments to see whether this weird “logits vs. token-ID” discrepancy is QuartzNet-only or more general—and it turns out it shows up on Parakeet models too.

1) New test results

| Model | .transcribe() text | Custom decode() text |
| --- | --- | --- |
| QuartzNet fine-tuned (CTC) | ni so tun bɛ se ka gafe kalan sufɛ gafe kalan tun bɛ diya a ye | nin sogo tun bɛ sekagɔ ka fakala sogo bɛk damarfen ka aa tumaw bɛ yemisiriyaka kay ɲɛ |
| Parakeet Hybrid CTC fine-tuned | ni so tun bɛ se ka gafe kalan sufɛ gafe kalan tun bɛ diya a ye | ni so tun bɛ se ka gafe ka la sufɛ gafe kalan bɛ diya a ye |
| Parakeet Hybrid RNNT fine-tuned | ni so tun bɛ se ka gafe kalan sufɛ gafe kalan tun bɛ diya a ye | ni so tun bɛ se ka gafe kalan sufɛ gafe kalan tu bɛ diya a ye |
| Parakeet CTC (0.6B BPE) original | lot to be sac graraphy color sufigraphy color to be dia ie we | wouldood to be sa car gffet cara sufred gffy got out to be j ie |

  1. Across all these models, my custom pipeline (with the same model.forward(...) + ctc_decoder_predictions_tensor) always returns a y_sequence of discrete token IDs (shape [T]), whereas .transcribe(return_hypotheses=True) returns a continuous logits tensor (shape [T, vocab_size]).

  2. Even on audio the models were trained on, the two pipelines differ by at least a character or an extra space, which breaks words, so it’s not just a random edge case. The audio I used for this test was actually in my fine-tuning training set. The language is Bambara.

  3. I double-checked that no spec-augment, dithering, channel-averaging, etc., are active in eval. The only thing I can’t explain is why the Hypothesis object from the official .transcribe() carries the raw logits (y_sequence), whereas my pipeline immediately collapses them to IDs.
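For reference, greedy CTC turns per-frame IDs into text in two steps: merge consecutive repeated IDs, then drop the blank symbol. The sketch below is the textbook rule, not NeMo's code; it assumes the blank is the last index (I believe NeMo CTC models place it at len(vocabulary), but check your model) and uses a toy character vocabulary.

```python
from typing import List, Sequence

def greedy_ctc_collapse(frame_ids: Sequence[int], blank_id: int) -> List[int]:
    """Merge consecutive duplicate IDs, then remove blanks (standard greedy CTC rule)."""
    out: List[int] = []
    prev = None
    for idx in frame_ids:
        if idx != prev and idx != blank_id:
            out.append(idx)
        prev = idx
    return out

# Toy vocabulary; the blank is assumed to be the index right after the last character.
vocab = [" ", "a", "k", "n"]
blank = len(vocab)
# The blank between the two runs of 3 keeps the doubled "n" in the output.
frames = [2, 2, blank, 1, 1, 0, blank, 3, 3, blank, 3, 1]
tokens = greedy_ctc_collapse(frames, blank)
print("".join(vocab[i] for i in tokens))  # ka nna
```

Because a single changed frame can split or merge a run, small logit changes can add or drop a character in the final text.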

2) Helper functions I’m using

import torch
import torchaudio
from typing import Tuple

def load_audio(audio_path: str, sample_rate: int = 16000) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load a mono audio file and return the waveform and its length in samples.
    Args:
        audio_path (str): Path to the audio file.
        sample_rate (int): Target sample rate; the audio is resampled if it differs.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Waveform of shape [channels, samples] and its length.
    """
    # ----- Audio preprocessing -----
    waveform, sr = torchaudio.load(audio_path)
    # Optional resampling
    if sample_rate is not None and sr != sample_rate:
        resampler = torchaudio.transforms.Resample(sr, sample_rate)
        waveform = resampler(waveform)
        sr = sample_rate
    length_tensor = torch.tensor([waveform.size(1)], dtype=torch.long)
    return waveform, length_tensor


def decode_batch(
    forward_output: torch.Tensor,
    enc_len: torch.Tensor,
    asr_model,
    return_hypotheses: bool = False,
) -> list[str]:
    """
    Generic decoder wrapper for CTC / RNNT models.
    """
    # Determine which decoder to call
    if hasattr(asr_model, 'cur_decoder') and asr_model.cur_decoder == 'ctc':
        # Hybrid RNNT/CTC model: forward() returns the encoder output, so apply the CTC head first
        log_probs = asr_model.ctc_decoder(encoder_output=forward_output)
        # Hybrid models keep their CTC decoding object in `ctc_decoding` (`decoding` is the RNNT one)
        ctc_decoding = getattr(asr_model, 'ctc_decoding', asr_model.decoding)
        hyps = ctc_decoding.ctc_decoder_predictions_tensor(
            decoder_outputs=log_probs,
            decoder_lengths=enc_len,
            return_hypotheses=return_hypotheses,
        )
    elif hasattr(asr_model.decoding, 'ctc_decoder_predictions_tensor'):
        # Pure CTC model: forward_output is already the log-probs
        hyps = asr_model.decoding.ctc_decoder_predictions_tensor(
            decoder_outputs=forward_output,
            decoder_lengths=enc_len,
            return_hypotheses=return_hypotheses,
        )
    elif hasattr(asr_model.decoding, 'rnnt_decoder_predictions_tensor'):
        hyps = asr_model.decoding.rnnt_decoder_predictions_tensor(
            encoder_output=forward_output,
            encoded_lengths=enc_len,
            return_hypotheses=return_hypotheses,
        )
    else:
        raise RuntimeError("No supported decoder found")

    # Extract text: depending on the NeMo version, decoding returns either a list or a
    # (best_hypotheses, all_hypotheses) tuple, holding str or Hypothesis objects
    if isinstance(hyps, tuple):
        hyps = hyps[0]
    return [h if isinstance(h, str) else h.text for h in hyps]

3) My questions

  1. Is there still any audio-level transformation (e.g. framing, windowing, extra trimming or normalization) that .transcribe() applies after loading but before model.forward(), which I’m missing? All my tests bypass that by feeding the raw waveform into model.forward(...).

  2. Why does .transcribe()’s Hypothesis carry a logits tensor for y_sequence, whereas the decoding helper always collapses to ID sequences immediately? I would expect the two pipelines to invoke exactly the same CTC/RNNT decoder—and yet the shapes and values differ.

  3. I sometimes see transcription shifts when I run multiple inferences with my decode function. The shift is usually small, as little as one extra or missing character. Could there be stateful interference (e.g. some buffer or random seed not fully reset) between calls?
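One way to probe question 3 empirically is to call the same computation several times on the same input and compare the outputs bit for bit; any difference points to active randomness (such as dropout) or to state carried between calls. A minimal sketch; the helper `outputs_repeat` and the toy model containing dropout are hypothetical stand-ins for the ASR model:

```python
from typing import Callable

import torch
import torch.nn as nn

def outputs_repeat(fn: Callable[[], torch.Tensor], runs: int = 3) -> bool:
    """Return True if calling `fn` `runs` times yields bit-identical tensors."""
    with torch.no_grad():
        first = fn()
        return all(torch.equal(first, fn()) for _ in range(runs - 1))

torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 4))  # hypothetical stand-in
x = torch.randn(2, 8)

toy.train()
print("train mode:", outputs_repeat(lambda: toy(x)))  # dropout active: outputs are expected to differ
toy.eval()
print("eval mode:", outputs_repeat(lambda: toy(x)))   # True
```

With a NeMo CTC model, the callable could be something like `lambda: model.forward(input_signal=audio, input_signal_length=length)[0]`.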

I’m really close to a faithful reproduction of model.transcribe(), but these last mysteries are blocking me. Any pointers on what else I can try—or if there’s an internal step in _transcribe_output_processing() (or elsewhere) that I haven’t exposed—would be hugely appreciated!

Thanks again for your time.

diarray-hub, May 11 '25 20:05

Update after moving to NeMo 2.3.0 (stable)

Hey @nithinraok, diarray here with the results of further experiments. I upgraded my NeMo version to 2.3.0 and tried a few things again. I now understand the “y_sequence as token IDs vs. logits” behavior. It only happens with CTC decoders now, and it is controlled by the return_hypotheses parameter. Even though all ASR model classes now return a list of Hypothesis objects, y_sequence contains token IDs by default, and explicitly setting return_hypotheses=True replaces them with logits. However, I found in nemo/collections/asr/models/ctc_models.py that the log_probs are simply stuffed into y_sequence, and only for CTC-type models, so this doesn't really help explain the text discrepancy between the two transcription pipelines.

So the core issue remains:
model.transcribe() and a direct call to model.forward() → decoding.<DECODER_TYPE>_decoder_predictions_tensor() still give different final texts and token‑ID sequences, ranging from tiny edits (one letter dropped) to large drifts depending on the model and the language of the input audio, even when:

  • the model is in .eval()
  • no timestamps / confidence are requested
  • greedy CTC is the only strategy (so no external LM / KenLM)
  • audio is mono 16 kHz and fed as raw waveform

Example outputs

  • Bambara fine‑tuned QuartzNet
transcribe : ni so tun bɛ se ka gafe kalan sufɛ gafe kalan tun bɛ diya a ye
custom     : nin sogo tun bɛ sekagɔ ka fakala sogo  bɛk damarfen ka aa tumaw bɛ yemisiriyaka kay ɲɛ

The same pattern occurs with

  • parakeet_0.6B-CTC (English)
transcribe : well i don't wish to see it any more observed phoebe ...
custom     : well i don't wish to see it anymore observed phoebe ...

A tiny text difference (any more -> anymore) that results from almost 10 mismatches in the token IDs returned as y_sequence.
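To see where two runs diverge, it can help to compare the frame-level ID sequences position by position instead of only the final text. A small sketch (the helper name `mismatched_frames` is hypothetical), assuming both sequences come from the same audio and therefore have the same number of frames:

```python
import torch

def mismatched_frames(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the frame indices where two equal-length [T] ID sequences disagree."""
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}")
    return (a != b).nonzero(as_tuple=True)[0]

ref = torch.tensor([5, 5, 0, 7, 7, 0, 3])
hyp = torch.tensor([5, 5, 0, 7, 0, 0, 4])
idx = mismatched_frames(ref, hyp)
print(idx.tolist(), f"{idx.numel()} of {ref.numel()} frames differ")  # [4, 6] 2 of 7 frames differ
```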

  • parakeet_110_TDT-CTC Hybrid (CTC branch and RNNT branch | Bambara)
  • RNNT decoder
transcribe : ni so tun bɛ se ka gafe kalan sufɛ gafe kalan tun bɛ diya a ye
custom     : ni so tun bɛ se ka gafe kalan sufɛ gafe kalan tu bɛ diya a ye

Again a tiny shift: tun -> tu. Strangely, with the Parakeet models these transcriptions change between runs.
  • CTC decoder
transcribe : ni so tun bɛ se ka gafe kalan sufɛ gafe kalan tun bɛ diya a ye
custom     : ni so tun bɛ se ka gafe ka la sufɛ gafe kalan bɛ diya a ye

A slightly more noticeable difference.

When the Parakeet models are evaluated on audio they were trained on, the gap sometimes shrinks to 1–3 characters, but it is still present. Even though the output of the custom pipeline for QuartzNet is more consistent (it doesn't change across runs), its gap is much larger.
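To quantify the text gap between the two pipelines, Python's standard difflib can list the character-level edit operations. A minimal sketch (the helper name `char_edits` is hypothetical) applied to the Hybrid RNNT outputs above:

```python
import difflib
from typing import List, Tuple

def char_edits(ref: str, hyp: str) -> List[Tuple[str, str, str]]:
    """Return (operation, ref_fragment, hyp_fragment) for every non-matching span."""
    matcher = difflib.SequenceMatcher(a=ref, b=hyp, autojunk=False)
    return [
        (op, ref[i1:i2], hyp[j1:j2])
        for op, i1, i2, j1, j2 in matcher.get_opcodes()
        if op != "equal"
    ]

transcribe_text = "ni so tun bɛ se ka gafe kalan sufɛ gafe kalan tun bɛ diya a ye"
custom_text = "ni so tun bɛ se ka gafe kalan sufɛ gafe kalan tu bɛ diya a ye"
print(char_edits(transcribe_text, custom_text))  # [('delete', 'n', '')]
```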


Questions

  1. Where does model.transcribe() alter the logits before they hit the decoder?
  2. Is there any other normalisation / padding / length clamp that transcribe() applies but a raw model.forward() doesn’t?
  3. Why do the outputs of model.transcribe and my pipeline change across multiple runs of the Parakeet models, but not with QuartzNet?

If someone could point to the exact code path or flag I’m missing, that would help A LOT. I’m trying to implement an RLHF approach to ASR, so any hidden transform breaks parity.

Thanks again for your time!


@istupakov, I have seen that you successfully reimplemented the preprocessing and post-processing of some Parakeet models in onnx-asr, and judging by the results in your HF space for onnx-asr, it works very well. Could you please help?

diarray-hub, May 13 '25 18:05

Hi @diarray-hub

I'm not exactly an expert in the Nemo framework. I've done a lot of research into the preprocessing in Nemo for my library, but I based the decoding on papers describing CTC/RNN-T/TDT, not on the Nemo source code.

Regarding your issue, I debugged it a bit and I think it occurs because transcribe() undoes model.eval() when it finishes!

Try to check this code:

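import torch
import torchaudio
import nemo.collections.asr as nemo_asr

model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-ctc-0.6b")

# The model expects 16 kHz mono audio; this sample file is assumed to already be 16 kHz
waveform, sample_rate = torchaudio.load('2086-149220-0033.wav')
result = model.transcribe(waveform[0])

# Put the model back in eval mode: transcribe() leaves some modules in train mode on exit
model.eval()
logits, lens, tokens = model.forward(input_signal=waveform, input_signal_length=torch.tensor([waveform.shape[-1]]))

# Element-wise comparison of frame-level token IDs (all True when the two paths match)
result[0].y_sequence == tokens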

istupakov, May 14 '25 03:05

Thank you so much @istupakov! You just put an end to a week of struggle where I was really trying to solve this problem.

I took a close look at the nemo.collections.asr.parts.mixins.transcription module and you were right: ASRTranscriptionMixin freezes the encoder/decoder/joint at the beginning of every model.transcribe call and explicitly unfreezes them at the end of transcription, putting those specific modules in train() mode on exit (see lines 742-780).
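A quick way to observe this state directly is to list every submodule whose `training` flag is set after a call. The sketch relies only on `torch.nn.Module.named_modules()`; with a NeMo model you would call it right after `model.transcribe(...)`. The helper name, the toy model and the simulated "unfreeze" step are hypothetical stand-ins, not NeMo code.

```python
from typing import List

import torch.nn as nn

def modules_in_train_mode(model: nn.Module) -> List[str]:
    """Names of all submodules currently in training mode ('' would be the root module)."""
    return [name for name, module in model.named_modules() if module.training]

toy = nn.Sequential(nn.Conv1d(4, 4, 3), nn.BatchNorm1d(4), nn.Dropout(0.1))
toy.eval()
toy[1].train()  # simulate one block being switched back to train mode
print(modules_in_train_mode(toy))  # ['1']
```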

Because my notebook called model.transcribe() first, those blocks were left in training mode, so I think dropout and BatchNorm skewed the logits on my next forward pass. That explains the discrepancy, and why smaller models like QuartzNet suffered the most: with fewer parameters and less training, I guess they are more sensitive to this. In particular, QuartzNet has a dropout layer at the end of every Jasper block.
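BatchNorm in train mode does more than normalise with batch statistics: each train-mode forward pass also updates the layer's running mean and variance, so a pass made in the wrong mode changes later eval-mode outputs too. A toy demonstration with `torch.nn.BatchNorm1d` (not QuartzNet itself):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
x = torch.randn(4, 3, 10) * 5 + 2  # (batch, channels, time), far from the initial running stats

bn.eval()
before = bn(x).clone()  # eval output using the initial running stats (mean 0, var 1)
bn.train()
_ = bn(x)  # train-mode pass: normalises with batch stats and updates running_mean / running_var
bn.eval()
after = bn(x)

print(torch.allclose(before, after))  # False: eval output changed after one train-mode pass
```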

Adding model.eval() (or freezing the blocks) before another forward pass makes the outputs match byte-for-byte. But it still leaves me wondering why. My theory is that .transcribe() might be called in validation loops, so they partially unfreeze the model (encoder and decoder) to make sure the next backward pass after the validation loop does not fail. But why do that in the .transcribe pipeline, which is presented as an inference API?
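Besides calling model.eval() afterwards, another option is to snapshot every submodule's `training` flag before transcribe() and restore it on exit. A minimal sketch; `preserve_train_flags` is a hypothetical helper, not a NeMo API, and with NeMo you would write `with preserve_train_flags(model): model.transcribe(...)`:

```python
import contextlib
from typing import Dict, Iterator

import torch.nn as nn

@contextlib.contextmanager
def preserve_train_flags(model: nn.Module) -> Iterator[None]:
    """Restore each submodule's train/eval flag on exit, whatever the body did."""
    saved: Dict[str, bool] = {name: m.training for name, m in model.named_modules()}
    try:
        yield
    finally:
        for name, module in model.named_modules():
            if name in saved:
                # Assign the flag directly: module.train(flag) would recurse into children
                module.training = saved[name]

toy = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.2))
toy.eval()
with preserve_train_flags(toy):
    toy.train()  # stand-in for a call that flips modules back to train mode
print(toy.training, toy[1].training)  # False False
```

Setting each flag individually, rather than calling `.train(flag)` on every module, avoids a parent's recursive call overwriting children that were already restored.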

Anyway, thank you very much @istupakov, you are my savior :)

diarray-hub, May 15 '25 16:05