
[WIP] - Support generating with fallback for short form audio in Whisper

Open kamilakesbi opened this pull request 1 year ago • 1 comment

What does this PR do?

The aim of this PR is to refactor the Whisper generate method so that short form and long form audio generation are handled by the same code path. It will support short form audio generation with fallback (as requested in #29508).
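"Generation with fallback" here refers to the temperature fallback strategy from the Whisper paper: a segment is re-decoded at a higher temperature whenever quality heuristics fail. Below is a minimal sketch of the idea only; the decode_segment callable, the DecodingResult fields, and the threshold values are illustrative assumptions, not the actual transformers implementation:

from dataclasses import dataclass

# Hypothetical container for one decoding attempt (illustrative only).
@dataclass
class DecodingResult:
    tokens: list
    avg_logprob: float
    compression_ratio: float

# Temperatures tried in increasing order, as in the Whisper paper.
TEMPERATURES = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)

def generate_with_fallback(decode_segment, segment, logprob_threshold=-1.0, compression_ratio_threshold=2.4):
    # Re-decode the segment with increasing temperature until both
    # quality heuristics pass, then return that attempt.
    for temperature in TEMPERATURES:
        result = decode_segment(segment, temperature=temperature)
        too_repetitive = result.compression_ratio > compression_ratio_threshold
        low_confidence = result.avg_logprob < logprob_threshold
        if not (too_repetitive or low_confidence):
            return result  # this attempt passed both quality checks
    return result  # all temperatures exhausted; keep the last attempt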

I've been working on a first draft of what it would look like. Here's what I've done for now:

  • Removed the code path previously dedicated to short form generation. Now, when a short form audio (or a batch of short form audios) is passed to generate, it is processed by the code path previously used for long form generation.

  • I still use an is_shortform parameter to distinguish between short form and long form audios. I've adapted the parts of the code that rely on this parameter:

--> _retrieve_max_frames_and_seek needs to be adapted: when processing batched short form audios, we don't necessarily need the attention_masks (a sketch of this idea follows after this list).

--> In short form generation, each sequence starts with the decoder_input_ids and ends with the eos token. I've made sure this is still the case with the new generate method.

--> I made sure we can still do short form generation when generation_config.no_timestamps_token_id is not defined.

--> I made sure we can still do short form generation when logits_processor is None.

  • I've also adapted the code to make it compatible with return_token_timestamps=True.
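To illustrate the attention mask point above, here is a minimal sketch of the kind of logic involved; the name, signature, and body are hypothetical, not the actual _retrieve_max_frames_and_seek implementation:

import torch

# Hypothetical sketch: when no attention_mask is available (batched short
# form audios), assume every sample spans the full padded input length.
def retrieve_max_frames_and_seek(batch_size, total_input_frames, attention_mask=None):
    if attention_mask is None:
        # short form batch: every sample is assumed to cover all frames
        max_frames = torch.full((batch_size,), total_input_frames, dtype=torch.long)
    else:
        # long form batch: the mask gives the true length of each sample
        max_frames = attention_mask.sum(-1).to(torch.long)
    seek = torch.zeros((batch_size,), dtype=torch.long)  # decoding starts at frame 0
    return max_frames, seek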

I ran the following snippet and compared the outputs of the old and new generate methods:

from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", torch_dtype=torch.float16)
model = model.to("cuda")

# Batched short form audios: 
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:2]")
audios = [sample["array"] for sample in dataset["audio"]]
# pass sampling_rate explicitly; the dummy LibriSpeech audio is sampled at 16 kHz
inputs = processor(audios, sampling_rate=16_000, return_tensors="pt", truncation=False).to("cuda", torch.float16)

result = model.generate(**inputs, return_timestamps=False)

The old and new generate methods give the same outputs with return_timestamps=False, return_timestamps=True, and return_token_timestamps=True.
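For reference, the generated ids from either method can be decoded into transcripts to eyeball the comparison (processor.batch_decode is the standard way to do this):

# decode the generated token ids into text for a side-by-side comparison
transcripts = processor.batch_decode(result, skip_special_tokens=True)
for transcript in transcripts:
    print(transcript)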

Next steps:

  • We currently get errors when num_return_sequences > 1 (see the illustrative call below).
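For illustration, a call along these lines (the parameter values are arbitrary examples, and the failure is the behavior described above, not a general transformers limitation) is the kind that currently fails:

# currently errors on this branch: requesting several candidates per input
result = model.generate(**inputs, num_beams=2, num_return_sequences=2)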

Who can review:

@sanchit-gandhi

kamilakesbi commented on May 23, 2024

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

I have run the test_modeling_whisper slow tests. Here are the currently failing tests (which pass on main):

  • test_assisted_decoding_matches_greedy_search_0_random
  • test_speculative_decoding_non_distil
  • test_generate_continue_from_past_key_values
  • test_tiny_en_generation
  • test_tiny_generation
  • test_tiny_en_batched_generation
  • test_tiny_timestamp_generation
  • test_tiny_token_timestamp_batch_generation
  • test_large_timestamp_generation

kamilakesbi commented on Jun 4, 2024

@sanchit-gandhi I have taken all your review comments into account. The CI and slow tests pass :) This PR should be ready for final review!

cc @amyeroberts.

kamilakesbi commented on Jun 7, 2024

Added two slow tests (test_whisper_shortform_single_batch_prev_cond and test_whisper_shortform_multi_batch_hard_prev_cond) and made a few modifications to make them pass.

kamilakesbi commented on Jun 7, 2024

Hi @ArthurZucker, could you please review this PR? I think the failing CI tests are unrelated to this work :)

kamilakesbi commented on Jun 11, 2024

@kamilakesbi As this is quite a big PR, is it possible to split it up? In particular, could we separate the enabling of speculative decoding from the unification of the long and short form audio logic?

amyeroberts commented on Jun 12, 2024

Hi @amyeroberts, in this PR we only focus on the unification of the long and short form audio logic! We don't change how speculative decoding is handled here :)

kamilakesbi commented on Jun 12, 2024

I've adapted the code to make it compatible with speculative decoding, return_token_timestamps, return_timestamps, return_dict_in_generate, num_beams >= 1, num_return_sequences >= 1, and the max_new_tokens and max_length parameters.

@kamilakesbi Could you update the PR description to reflect the current state (either disambiguate for confused people like me or remove :) )?

amyeroberts commented on Jun 12, 2024

@amyeroberts, sorry about that. We've indeed worked through many points with Sanchit to make this PR pass. I've updated the PR description to reflect the main changes in the current state. Hope this makes it clearer for you!

kamilakesbi commented on Jun 12, 2024

I'll take care of reviewing it in a bit!

ArthurZucker commented on Jun 19, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] commented on Jul 13, 2024

Gentle ping @ArthurZucker

kamilakesbi commented on Jul 16, 2024

@ArthurZucker thanks for your review! I took your remarks into account :)

The failing tests are unrelated to this PR. If this is ok for you, we can perhaps merge now or wait for the CI to be green.

kamilakesbi commented on Jul 17, 2024

Let's wait for the full CI; it seems alright now!

ArthurZucker commented on Jul 18, 2024

Also, there's a question that wasn't answered!

ArthurZucker commented on Jul 18, 2024

Yes, the CI is green :) If it's ok for you, I can merge!

kamilakesbi commented on Jul 18, 2024