[WIP] Refactor Deberta/Deberta-v2
What does this PR do?
Refactor both Deberta and DebertaV2 to make them more compatible with the overall `transformers` library.
This should fix a bunch of issues related to torch-scripting with Deberta:
- #15216
- #15673
- #16456
- #18659
- #21300
- #20815
- #12436
- #18674
- help support the Prefix-Tuning PEFT approach
Hey @ArthurZucker, any updates on this? Any ETA for when it will be merged into `main`?
Hey! Just got back from holidays, this should be my main focus in the coming days!
Sorry! Seems like I had to postpone this! If anyone wants to take over, feel free to do so; otherwise it will be my priority once https://github.com/huggingface/transformers/pull/23909 is merged!
Regarding `z_steps` in `DebertaV2Model`: I think this code is relevant for the enhanced mask decoder of the generator model:
```python
if attention_mask.dim() <= 2:
    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
    att_mask = extended_attention_mask.byte()
    attention_mask = att_mask * att_mask.squeeze(-2).unsqueeze(-1)
elif attention_mask.dim() == 3:
    attention_mask = attention_mask.unsqueeze(1)

target_mask = target_ids > 0
hidden_states = encoder_layers[-2]
if not self.position_biased_input:
    layers = [encoder.layer[-1] for _ in range(2)]
    z_states += hidden_states
    query_states = z_states
    query_mask = attention_mask
    outputs = []
    rel_embeddings = encoder.get_rel_embedding()
    for layer in layers:
        # TODO: pass relative pos ids
        output = layer(hidden_states, query_mask, return_att=False, query_states=query_states,
                       relative_pos=relative_pos, rel_embeddings=rel_embeddings)
        query_states = output
        outputs.append(query_states)
else:
    outputs = [encoder_layers[-1]]
```
As far as I can tell, they hardcoded `z_steps` to 2 here, although it should still be left as 0 for the discriminator. Adding `z_steps` to the config seems like a good idea.
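For illustration, here is a minimal sketch of that config wiring (the `z_steps` attribute and the `getattr` fallback are assumptions about what this PR could add, not the current public API):

```python
from transformers import DebertaV2Config

# Hypothetical: z_steps is not an official DebertaV2Config field today; extra kwargs are
# simply stored as attributes on the config object, which is enough for this sketch.
generator_config = DebertaV2Config(z_steps=2)  # MLM generator uses the enhanced mask decoder
discriminator_config = DebertaV2Config()       # discriminator keeps the default behaviour

# Inside DebertaV2Model.__init__ the model could then read it with a safe default of 0:
print(getattr(generator_config, "z_steps", 0))      # 2
print(getattr(discriminator_config, "z_steps", 0))  # 0
```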
`z_states` represents the position embeddings, which are non-zero if `position_biased_input` is set to `True`; they are passed in from the embedding layer. So, to implement this properly, I think we need to return the position embeddings here:
```python
class DebertaV2Embeddings(nn.Module):
    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
        ...
        return embeddings, position_embeddings
```
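For context, a rough sketch of how the body of `DebertaV2Embeddings.forward` could produce both values; it assumes the current structure of that class (where `position_embeddings` is `None` when `position_biased_input=False`), and the elided steps (token type embeddings, LayerNorm, masking, dropout) stay unchanged:

```python
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)
    if position_ids is None:
        # default position ids (simplified; the real code uses a registered buffer)
        position_ids = torch.arange(inputs_embeds.size(1), device=inputs_embeds.device).unsqueeze(0)
    if self.position_embeddings is not None:
        position_embeddings = self.position_embeddings(position_ids.long())
    else:
        # no absolute position table when position_biased_input=False -> all zeros
        position_embeddings = torch.zeros_like(inputs_embeds)
    embeddings = inputs_embeds
    if self.position_biased_input:
        embeddings = embeddings + position_embeddings
    # ... token type embeddings, LayerNorm, mask handling and dropout as before ...
    return embeddings, position_embeddings  # the second return value is the new part
```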
and implement the z_steps like this:
```python
class DebertaV2Model(DebertaV2PreTrainedModel):
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        ...
        # the embedding layer now also returns the (absolute) position embeddings
        embedding_output, position_embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        ...
        if self.z_steps > 0:
            # enhanced mask decoder: re-run the last encoder layer z_steps times,
            # injecting the absolute position information into the query states
            hidden_states = encoded_layers[-2]
            layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
            position_embedding_output += hidden_states
            query_states = position_embedding_output
            query_mask = self.encoder.get_attention_mask(attention_mask)
            rel_embeddings = self.encoder.get_rel_embedding()
            rel_pos = self.encoder.get_rel_pos(embedding_output)
            for layer in layers:
                query_states = layer(
                    hidden_states,
                    query_mask,
                    output_attentions=False,
                    query_states=query_states,
                    relative_pos=rel_pos,
                    rel_embeddings=rel_embeddings,
                )
                encoded_layers = encoded_layers + (query_states,)
```
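One nice property of this design (a sketch of the intent, not something verified in this PR): since `z_steps` would default to 0, the new branch is never entered for discriminator checkpoints, so their outputs should be unchanged before and after the refactor. Something along these lines could be used as a regression check:

```python
import torch
from transformers import AutoTokenizer, DebertaV2Model

# Hypothetical regression check: a discriminator checkpoint (z_steps == 0) should skip the
# enhanced-mask-decoder loop entirely, so its hidden states must not change after the refactor.
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = DebertaV2Model.from_pretrained("microsoft/deberta-v3-base").eval()

inputs = tokenizer("DeBERTa uses disentangled attention.", return_tensors="pt")
with torch.no_grad():
    reference = model(**inputs).last_hidden_state

# After applying the refactor, recompute the same forward pass and compare:
# torch.testing.assert_close(reference, refactored_output)
```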
What is the status of this PR? The logs of the CI checks have expired.
#27734 should help with some of the issues in the meantime.