Correct Whisper's beam search scores computation
Fixes #32246
There have been many failing Whisper tests over the past few days, so I'd rather wait for them to be fixed before merging this PR.
What does this PR do?
@cifkao made a great summary of the current issue in #32246:
TL;DR: Scores corresponding to the wrong sequence in the batch/beam are returned.
He also correctly identified the origin of the issue:
The bug seems to be here in _postprocess_outputs. This works fine with num_beams==1, but with num_beams>1, the shape of the items in seek_outputs["scores"] will be [num_beams * batch_size, vocab_size], while the code expects it to be [batch_size, vocab_size]. Therefore, instead of choosing the correct sequence in the beam/batch, this code will incorrectly combine scores from different sequences.
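A toy illustration of the mismatch, using plain Python lists instead of tensors. The row layout (rows grouped by batch item, so row = batch_idx * num_beams + beam) and the helper buggy_pick are assumptions made for illustration; this is not the actual _postprocess_outputs code.

```python
batch_size, num_beams = 2, 2

# Each string stands in for one [vocab_size] row of scores at a single generation step.
# Assumed layout: row = batch_idx * num_beams + beam (rows grouped by batch item).
step_scores = [f"batch{b}/beam{k}" for b in range(batch_size) for k in range(num_beams)]


def buggy_pick(step_scores, batch_idx):
    # Hypothetical helper: indexes the first dimension as if it had only batch_size rows,
    # ignoring which beam actually produced the kept token.
    return step_scores[batch_idx]


print(len(step_scores))            # num_beams * batch_size rows, not batch_size
print(buggy_pick(step_scores, 1))  # a row that belongs to batch item 0, not batch item 1
```

With batch_size=1 (as in #32246), batch_idx is always 0, so the beam-0 row is returned even for tokens that came from beam 1.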
The fix simply consists of taking the correct logits_scores row for each generated token.
Instead of taking the batch_idx-th row of logits_scores out of the num_beams * batch_size rows, we now take the beam_idx-th row.
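A minimal sketch of the corrected selection, again with lists standing in for tensors. It assumes each beam index is a row index into the flattened num_beams * batch_size dimension and that -1 marks padding once a sequence has finished; pick_scores and the beam index values are hypothetical.

```python
batch_size, num_beams, num_steps = 2, 2, 2

# Per-step scores with num_beams * batch_size rows; strings stand in for [vocab_size] rows.
scores = [
    [f"step{t}/row{r}" for r in range(num_beams * batch_size)]
    for t in range(num_steps)
]

# Hypothetical beam indices of each finished sequence: one flattened row index per
# generated token (assumed convention: -1 marks padding after the sequence ends).
beam_indices = [
    [0, 1],   # batch item 0: token 0 from row 0, token 1 from row 1
    [3, -1],  # batch item 1: finished after one token
]


def pick_scores(scores, beam_indices, batch_idx):
    # For every step, take the row of the beam that actually produced the token,
    # instead of row batch_idx.
    return [scores[t][b] for t, b in enumerate(beam_indices[batch_idx]) if b >= 0]


print(pick_scores(scores, beam_indices, 0))  # ['step0/row0', 'step1/row1']
print(pick_scores(scores, beam_indices, 1))  # ['step0/row3']
```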
Reproduction results
I've rerun the code snippet from #32246.
How to read the results:
The first set of scores contains the score of each generated token, along with its beam index. The second set contains the scores from a manual forward pass over the generated tokens; they indicate the "true scores" we should get.
Notice how in #32246, the scores of the tokens coming from beam index 1 differ from the recomputed scores, which shows that the wrong scores were selected.
Here, they're roughly the same, which shows that we now select the right beam indices.
Scores out of the generation:
('<|0.00|>', -0.06171704828739166, 0)
(' Folks', -1.9032700061798096, 0)
(',', -0.40583235025405884, 0)
(' if', -0.03763910010457039, 0)
(' you', -0.0019693044014275074, 0)
(' watch', -0.14575302600860596, 0)
(' the', -0.2036631554365158, 0)
(' show', -0.002341626212000847, 0)
(',', -0.2806797921657562, 0)
(' you', -0.290231853723526, 0)
(' know', -0.025554247200489044, 0)
(' I', -1.0598242282867432, 0)
(' spent', -0.5059170722961426, 1)
(' a', -0.02328178472816944, 1)
(' lot', -0.02414931170642376, 1)
(' of', -0.02351410686969757, 1)
(' time', -0.015122056938707829, 1)
(' right', -1.1174389123916626, 1)
(' over', -0.020583242177963257, 1)
(' there', -0.031000398099422455, 0)
('.', -0.23914632201194763, 0)
('<|5.12|>', -3.7109971046447754, 0)
Scores out of the forward:
('<|en|>', -0.3857421875)
('<|transcribe|>', -6.556510925292969e-06)
('<|0.00|>', -0.1939697265625)
(' Folks', -1.931640625)
(',', -0.40966796875)
(' if', -0.0380859375)
(' you', -0.002063751220703125)
(' watch', -0.1456298828125)
(' the', -0.2041015625)
(' show', -0.00235748291015625)
(',', -0.283447265625)
(' you', -0.2978515625)
(' know', -0.0259857177734375)
(' I', -1.0849609375)
(' spent', -0.499755859375)
(' a', -0.0235137939453125)
(' lot', -0.02398681640625)
(' of', -0.0230560302734375)
(' time', -0.0152587890625)
(' right', -1.12109375)
(' over', -0.0210113525390625)
(' there', -0.030975341796875)
('.', -0.2425537109375)
('<|5.12|>', -3.802734375)
...
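To put a number on "roughly the same", here is a small sketch that pairs the two listings above token by token (the values are copied from the output; the forward pass's two prompt entries, <|en|> and <|transcribe|>, are skipped) and reports the largest absolute gap overall and among the beam-1 tokens.

```python
# (token, score from generate, beam index, score from forward pass), copied from the output above
rows = [
    ("<|0.00|>", -0.06171704828739166, 0, -0.1939697265625),
    (" Folks", -1.9032700061798096, 0, -1.931640625),
    (",", -0.40583235025405884, 0, -0.40966796875),
    (" if", -0.03763910010457039, 0, -0.0380859375),
    (" you", -0.0019693044014275074, 0, -0.002063751220703125),
    (" watch", -0.14575302600860596, 0, -0.1456298828125),
    (" the", -0.2036631554365158, 0, -0.2041015625),
    (" show", -0.002341626212000847, 0, -0.00235748291015625),
    (",", -0.2806797921657562, 0, -0.283447265625),
    (" you", -0.290231853723526, 0, -0.2978515625),
    (" know", -0.025554247200489044, 0, -0.0259857177734375),
    (" I", -1.0598242282867432, 0, -1.0849609375),
    (" spent", -0.5059170722961426, 1, -0.499755859375),
    (" a", -0.02328178472816944, 1, -0.0235137939453125),
    (" lot", -0.02414931170642376, 1, -0.02398681640625),
    (" of", -0.02351410686969757, 1, -0.0230560302734375),
    (" time", -0.015122056938707829, 1, -0.0152587890625),
    (" right", -1.1174389123916626, 1, -1.12109375),
    (" over", -0.020583242177963257, 1, -0.0210113525390625),
    (" there", -0.031000398099422455, 0, -0.030975341796875),
    (".", -0.23914632201194763, 0, -0.2425537109375),
    ("<|5.12|>", -3.7109971046447754, 0, -3.802734375),
]

# Absolute gap between the generate score and the forward-pass score for every token
gaps = [(tok, abs(gen - fwd), beam) for tok, gen, beam, fwd in rows]
worst = max(gaps, key=lambda g: g[1])
worst_beam1 = max((g for g in gaps if g[2] == 1), key=lambda g: g[1])

print(f"largest gap overall: {worst[0]!r} {worst[1]:.4f}")
print(f"largest gap on beam-1 tokens: {worst_beam1[0]!r} {worst_beam1[1]:.4f}")
```

The beam-1 tokens stay within about 0.01 of the forward pass, while the two largest gaps sit on the timestamp tokens <|0.00|> and <|5.12|>.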
**Code:**
from datasets import Audio, load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch
import numpy as np

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny", torch_dtype=torch.float16
)
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model.cuda()

# Load a long-form audio sample and resample it to 16 kHz
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
audio = ds[0]["audio"]["array"].astype(np.float32)

inputs = processor(
    [audio],
    return_tensors="pt",
    truncation=False,
    padding="longest",
    sampling_rate=16_000,
)
inputs = inputs.to(model.device, torch.float16)

# Long-form generation with beam search (num_beams=2), returning per-step scores
generation_output = model.generate(
    **inputs,
    language="en",
    return_timestamps=True,
    return_segments=True,
    output_scores=True,
    num_beams=2,
    # num_return_sequences=1,
    temperature=0.0,
    logprob_threshold=0.0,
    compression_ratio_threshold=2.4,
    no_speech_threshold=0.6,
)

# Print each token along with its log-probability and beam index
segment = generation_output["segments"][0][0]
tokens = segment["result"]["sequences"]
scores = segment["result"]["scores"]
beam_indices = segment["result"]["beam_indices"]
logprobs = torch.as_tensor([s.float().log_softmax(-1)[t] for s, t in zip(scores, segment["tokens"])])
print(
    *[
        (processor.tokenizer.decode([t], decode_with_timestamps=True), s.item(), b.item())
        for s, t, b in zip(logprobs, tokens, beam_indices)
    ],
    sep="\n",
)

# Now run a forward pass with the generated tokens (first 30 s window, i.e. 3000 frames)
inputs_forward = {k: v[..., :3000].cuda() for k, v in inputs.items()}
inputs_forward["decoder_input_ids"] = torch.cat(
    [
        torch.as_tensor(processor.tokenizer.encode("<|startoftranscript|><|en|><|transcribe|>", add_special_tokens=False)),
        tokens,
    ],
)[None].cuda()
with torch.inference_mode():
    output_forward = model(**inputs_forward)

# Print each token along with its log-probability
# (the logits at position i score decoder token i + 1, hence the [1:] shift)
forward_logprobs = torch.nn.functional.log_softmax(output_forward.logits.squeeze(0), dim=-1)
target_ids = inputs_forward["decoder_input_ids"].squeeze(0)[1:]
print(
    *[
        (processor.tokenizer.decode([t], decode_with_timestamps=True), s[t].item())
        for s, t in zip(forward_logprobs, target_ids)
    ],
    sep="\n",
)
cc @LysandreJik, @kamilakesbi and @sanchit-gandhi
Nice, that looks like the correct fix to me!
I suspect that the other items (attentions, hidden states, logits) will also have a first dimension of size num_beams * batch_size, so they might require indexing by beam_idx instead of batch_idx as well?
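If these items do have num_beams * batch_size rows, as suspected above, the same beam-index selection could be applied to all of them. A sketch under that assumption, with strings standing in for tensors; select_row, gather_for_item, the per-layer tuple nesting, and the beam index values are hypothetical, not the transformers implementation.

```python
def select_row(item, row):
    # Leaves are assumed to have num_beams * batch_size rows; tuples are assumed to hold
    # one entry per layer (as for attentions / hidden states), so recurse into them.
    if isinstance(item, tuple):
        return tuple(select_row(x, row) for x in item)
    return item[row]


def gather_for_item(per_step, item_beam_indices):
    # Pick, at every step, the row of the beam that produced the kept token (-1 = padding).
    return [select_row(per_step[t], b) for t, b in enumerate(item_beam_indices) if b >= 0]


num_rows, num_steps, num_layers = 4, 2, 2  # num_rows = num_beams * batch_size

# Logits: one set of rows per step. Attentions: one set of rows per layer per step.
logits = [[f"logits/s{t}/r{r}" for r in range(num_rows)] for t in range(num_steps)]
attentions = [
    tuple([f"attn/L{layer}/s{t}/r{r}" for r in range(num_rows)] for layer in range(num_layers))
    for t in range(num_steps)
]

item_beam_indices = [3, 2]  # hypothetical beam indices for one batch item
print(gather_for_item(logits, item_beam_indices))
print(gather_for_item(attentions, item_beam_indices))
```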
(Also, for anyone wondering why the scores are not exactly the same, it's likely because of the logits processors SuppressTokensLogitsProcessor and WhisperTimeStampLogitsProcessor, which suppress certain tokens.)
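A small worked example of why suppression shifts the numbers: masking a token with -inf before the softmax renormalizes the remaining probability mass, so the chosen token's log-probability differs from what an unprocessed forward pass reports. The logits below are made up, and suppress is a simplified stand-in that only models the masking effect; the real processors choose which tokens to mask with their own rules.

```python
import math


def log_softmax(xs):
    # Numerically stable log-softmax; -inf entries contribute exp(-inf) = 0.
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def suppress(logits, token_ids):
    # Simplified stand-in for a suppression processor: masked tokens get -inf.
    return [-math.inf if i in token_ids else x for i, x in enumerate(logits)]


logits = [2.0, 1.0, 0.5, 0.0]  # made-up logits over a 4-token vocabulary
chosen = 0                     # token that was actually generated

plain = log_softmax(logits)[chosen]                     # unprocessed forward pass
processed = log_softmax(suppress(logits, {3}))[chosen]  # after masking token 3
print(f"{plain:.3f} {processed:.3f}")
```

Because log-softmax is shift-invariant, applying the mask to log-probabilities instead of raw logits and renormalizing gives the same result.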
Merging now