Transformers-Tutorials

Can `VisionEncoderDecoderModel.generate()` work with batched data?

Status: open · plamb-viso opened this issue 2 years ago · 2 comments

Sorry if this is the wrong place to post this.

I'm currently trying to fine-tune Donut, using your excellent fine-tuning guide as a starting point. As a test, I am calling `VisionEncoderDecoderModel.generate()` like this:

```python
outputs = model_generator.generate(
    batch['pixel_values'],
    decoder_input_ids=decoder_input_ids,
    max_length=model_generator.decoder.config.max_position_embeddings,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=False,
    num_beams=1,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)
```

`batch['pixel_values']` has a shape like this: `pixel_values_shape=torch.Size([2, 3, 500, 647])`

That is, two page images are passed to `.generate()`, so I expected the first dimension of `outputs.sequences` to be 2, but it is always 1: `outputs=torch.Size([1, 5])`, `output=tensor([57552, 57550, 57526, 57551, 2], device='cuda:3')`

I realize I could iterate over the batch and generate each sequence independently, but I'd be surprised if `.generate()` couldn't handle batched input like this. Is there something I'm missing?
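A possible cause, offered as an unconfirmed assumption rather than a diagnosis from this thread: in the Donut fine-tuning setup, `decoder_input_ids` is usually the tokenized task prompt, whose batch dimension is 1, while `batch['pixel_values']` has a batch dimension of 2. If the prompt is not repeated per image, the decoder may effectively run only one sequence. The minimal sketch below repeats a single prompt so there is one copy per image; the helper name `batch_decoder_prompt` and the prompt ids are hypothetical placeholders.

```python
import torch


def batch_decoder_prompt(decoder_input_ids: torch.Tensor, pixel_values: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: give every image in the batch its own copy of the prompt.

    decoder_input_ids: shape [1, prompt_len] (a single tokenized task prompt)
    pixel_values:      shape [batch_size, 3, H, W]
    """
    batch_size = pixel_values.shape[0]
    if decoder_input_ids.shape[0] == batch_size:
        return decoder_input_ids  # already one prompt per image
    if decoder_input_ids.shape[0] != 1:
        raise ValueError("expected a single prompt or one prompt per image")
    # repeat(batch_size, 1) copies the single row batch_size times along dim 0;
    # the result stays on the same device as the input
    return decoder_input_ids.repeat(batch_size, 1)


# Shapes from the issue: two page images; the 3-token prompt ids are placeholders
pixel_values = torch.zeros(2, 3, 500, 647)
prompt_ids = torch.tensor([[0, 1, 2]])
batched = batch_decoder_prompt(prompt_ids, pixel_values)
print(batched.shape)  # torch.Size([2, 3])
```

Under this assumption, the result would be passed as `decoder_input_ids=batch_decoder_prompt(decoder_input_ids, batch['pixel_values'])` in the `generate()` call; whether this alone yields an output batch dimension of 2 has not been verified here.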

plamb-viso · Oct 17 '23 21:10

Hi,

Yes, in theory the `VisionEncoderDecoderModel` class should support batched generation. Could you open an issue about this on the Transformers repository (if there isn't one already)?

NielsRogge · Oct 18 '23 07:10

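```python
from PIL import Image

# Assumes `processor` and `model` (a VisionEncoderDecoderModel) are already loaded.
def ocr_image(image_paths):
    # Open every image in the batch; convert("RGB") guards against grayscale/RGBA files
    image_list = [Image.open(path).convert("RGB") for path in image_paths]
    # The processor stacks the images into one pixel_values tensor of shape [batch, 3, H, W]
    inputs = processor(images=image_list, return_tensors="pt")
    generated_ids = model.generate(**inputs)
    # batch_decode returns one string per input image
    print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```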

This code works when `ocr_image` is given a list of image paths.

janakiram180 · Apr 15 '24 10:04