FLAVA not doing a forward pass
System Info
- `transformers` version: 4.27.0.dev0
- Platform: Linux-5.15.90.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
- Python version: 3.10.6
- Huggingface_hub version: 0.12.0
- PyTorch version (GPU?): not installed (NA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
N.B. I do have PyTorch installed; I'm not sure why the tool can't find it:

```
python -c "import torch; print(torch.__version__)"
2.1.0.dev20230310
```
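For what it's worth, here is a minimal check of whether the interpreter that imports transformers also sees torch; my assumption (not verified) is that `transformers-cli env` simply ran in a different environment:

```python
# Diagnostic sketch: confirm which interpreter is used and whether transformers detects torch.
import sys
import transformers
from transformers.utils import is_torch_available

print(sys.executable)            # path of the interpreter being used
print(transformers.__version__)  # should print 4.27.0.dev0
print(is_torch_available())      # False would explain the "not installed (NA)" line above
```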
Who can help?
@apsdehal
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [X] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
Steps to reproduce the behavior (there is also a Colab notebook doing this):
- Get a datapoint for a forward pass (`fetch_images` is in the notebook above; a rough sketch of the helper is included after the snippet for reference):

```python
import datasets

pmd = datasets.load_dataset("facebook/pmd", "wit", use_auth_token=True, streaming=True)
pmd_train_head = pmd['train'].take(2)
pmd_train_head_with_images = pmd_train_head.map(fetch_images, batched=True, batch_size=100, fn_kwargs={"num_threads": 20})
datapoint = next(iter(pmd_train_head_with_images))
```
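For readers without the notebook, this is a rough, hypothetical sketch of what a `fetch_images` helper along these lines could look like; the actual implementation lives in the notebook, and the `image_url`/`image` column names are assumptions:

```python
import io
import urllib.request
from concurrent.futures import ThreadPoolExecutor

import PIL.Image


def fetch_single_image(image_url, timeout=10):
    # Download one image and decode it with PIL; return None if the download fails.
    try:
        with urllib.request.urlopen(image_url, timeout=timeout) as response:
            return PIL.Image.open(io.BytesIO(response.read())).convert("RGB")
    except Exception:
        return None


def fetch_images(batch, num_threads):
    # Resolve the image URLs of a batch into PIL images in parallel.
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        batch["image"] = list(executor.map(fetch_single_image, batch["image_url"]))
    return batch
```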
- Process the input:

```python
from transformers import FlavaProcessor, FlavaForPreTraining

processor = FlavaProcessor.from_pretrained("facebook/flava-full")
inputs = processor(
    text=[datapoint['text']],
    images=[datapoint['image']],
    return_tensors="pt",
    padding="max_length",
    max_length=77,
    return_codebook_pixels=True,
    return_image_mask=True,
    return_attention_mask=True,
    return_token_type_ids=True,
    return_special_tokens_mask=True,
)
inputs.bool_masked_pos
```
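Optionally, a quick look at what the processor produced (just for inspection, not required for the forward pass):

```python
# Every entry should be a tensor with batch size 1.
for name, tensor in inputs.items():
    print(name, tuple(tensor.shape), tensor.dtype)
```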
- Mask the text input for MLM:

```python
from transformers import DataCollatorForLanguageModeling, AutoTokenizer

data_collator = DataCollatorForLanguageModeling(processor.tokenizer, mlm=True, mlm_probability=0.4, return_tensors="pt")
inputs['input_ids'], inputs['input_ids_masked'] = data_collator.torch_mask_tokens(
    inputs=inputs['input_ids'],
    special_tokens_mask=inputs['special_tokens_mask'],
)
del inputs['special_tokens_mask']
```
- Do a forward pass:

```python
model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
outputs = model(**inputs)
loss = outputs.loss
print(f"loss: {loss}")
```
Expected behavior
I would expect the forward pass to run without throwing errors.
Actual behavior
```
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-14-b821d73f49e9> in <module>
1 model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
2
----> 3 outputs = model(**inputs)
---------------------------------------------------------------------------
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.9/dist-packages/transformers/models/flava/modeling_flava.py in forward(self, input_ids, input_ids_masked, pixel_values, codebook_pixel_values, attention_mask, token_type_ids, bool_masked_pos, position_ids, image_attention_mask, skip_unmasked_multimodal_encoder, mlm_labels, mim_labels, itm_labels, output_attentions, output_hidden_states, return_dict, return_loss)
1857 )
1858
-> 1859 flava_masked_output = self.flava(
1860 input_ids=input_ids_masked,
1861 pixel_values=pixel_values,
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.9/dist-packages/transformers/models/flava/modeling_flava.py in forward(self, input_ids, pixel_values, attention_mask, token_type_ids, bool_masked_pos, position_ids, image_attention_mask, skip_multimodal_encoder, output_attentions, output_hidden_states, return_dict)
1403 text_output = None
1404 if input_ids is not None:
-> 1405 text_output = self.text_model(
1406 input_ids=input_ids,
1407 attention_mask=attention_mask,
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.9/dist-packages/transformers/models/flava/modeling_flava.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, output_attentions, output_hidden_states, return_dict)
1061 )
1062
-> 1063 embedding_output = self.embeddings(
1064 input_ids=input_ids,
1065 token_type_ids=token_type_ids,
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.9/dist-packages/transformers/models/flava/modeling_flava.py in forward(self, input_ids, token_type_ids, position_ids)
417 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
418
--> 419 inputs_embeds = self.word_embeddings(input_ids)
420 token_type_embeddings = self.token_type_embeddings(token_type_ids)
421
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
158
159 def forward(self, input: Tensor) -> Tensor:
--> 160 return F.embedding(
161 input, self.weight, self.padding_idx, self.max_norm,
162 self.norm_type, self.scale_grad_by_freq, self.sparse)
/usr/local/lib/python3.9/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
2208 # remove once script supports set_grad_enabled
2209 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
2211
2212
IndexError: index out of range in self
```
Hi @amariucaitheodor. Thank you for reporting the issue!
Could you also copy-paste the error (traceback) you got into the issue description above? Thanks.
I tried the Colab and found the issue. Specifically, the code used to compute `input_ids` and `input_ids_masked` is incorrect: `torch_mask_tokens` returns the masked input ids together with the corresponding labels. Since the loss is only calculated on the masked tokens, all other positions are set to -100 in the labels, and those -100 values cause an "index out of range" error down the line in the embeddings' forward.
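A minimal, self-contained illustration of that failure mode (nothing FLAVA-specific; any embedding layer rejects indices like -100):

```python
import torch
import torch.nn as nn

# The labels returned by torch_mask_tokens use -100 for positions that are not masked.
# Feeding such values to an embedding layer raises the same error as in the traceback above.
embedding = nn.Embedding(num_embeddings=30522, embedding_dim=8)
bad_ids = torch.tensor([[101, -100, 2023, -100, 102]])
embedding(bad_ids)  # IndexError: index out of range in self
```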
Thank you for the reply! I had noticed the same problem.
What is then the correct way of calculating input_ids_masked? The code doesn't work with DataCollatorForLanguageModeling for the reasons mentioned above, and there is no other example for doing this.
Thank you @amariucaitheodor for providing the error log, and thanks @apsdehal for sharing your finding. I will take a look at this issue. But @apsdehal, don't hesitate to share if you have any idea regarding the correct solution ❤️
Hello! After looking into the issue with the notebook, here is my finding:

- `data_collator.torch_mask_tokens(inputs=inputs['input_ids'], ...)` returns two items:
  - the first item is the input ids with masking applied
  - the second item is the labels, where:
    - a position with value -100 means that position is not masked
    - otherwise, the position holds the original value from `inputs`
- The `FlavaForPreTraining` model expects `input_ids_masked` to be the masked inputs, i.e. the first item above. See https://github.com/huggingface/transformers/blob/f7329751fe5c43365751951502c00df5a4654359/src/transformers/models/flava/modeling_flava.py#L803-L805
- However, in the notebook you do `inputs['input_ids'], inputs['input_ids_masked'] = data_collator.torch_mask_tokens(...)`, which makes `inputs['input_ids_masked']` the second item returned by `torch_mask_tokens`, which is incorrect. In particular, it contains -100 values, which cause the error. Furthermore, `inputs['input_ids']` also ends up with the wrong value, although that alone does not crash the program.
The solution is just to prepare the correct inputs for the model:

```python
inputs['input_ids_masked'], _ = data_collator.torch_mask_tokens(
    inputs=inputs['input_ids'],
    special_tokens_mask=inputs['special_tokens_mask']
)
```
With this change, I get loss: 7.162976264953613.
Let me know if you have further questions 🤗
@ydshieh I don't think this is quite correct either, as `torch_mask_tokens` masks the `input_ids` in place, so you will have to clone the `input_ids` before passing them to it.
@apsdehal Thanks a lot, nice catch! You are 100% correct. @amariucaitheodor Please see this comment too!
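Putting the two corrections together, something along these lines should work (a sketch, I have not re-run the notebook with it):

```python
# Clone first so torch_mask_tokens does not modify inputs['input_ids'] in place,
# and keep only the first returned item (the masked ids) as input_ids_masked.
inputs['input_ids_masked'], _ = data_collator.torch_mask_tokens(
    inputs=inputs['input_ids'].clone(),
    special_tokens_mask=inputs['special_tokens_mask'],
)
del inputs['special_tokens_mask']
```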
As it turns out, this is not an issue in the modeling code in transformers but rather incorrect preparation of the model inputs, so I am going to close the issue.
@amariucaitheodor If you still have issues, you can post on the Hugging Face Forums.
However, if you find other issues that you believe are in the modeling code, feel free to continue leaving comments here.