
[WIP] Refactor Deberta/Deberta-v2

Open · ArthurZucker opened this issue 3 years ago

What does this PR do?

Refactors both DeBERTa and DeBERTa-v2 to make them more consistent with the rest of the transformers library.

Should fix a number of issues related to TorchScript export of DeBERTa (a minimal repro sketch follows the list):

  • #15216
  • #15673
  • #16456
  • #18659
  • #21300
  • #20815
  • #12436
  • #18674
  • helps support the Prefix-Tuning PEFT approach
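
For context, a minimal sketch of the TorchScript export these issues revolve around (the checkpoint name and inputs are illustrative, not part of this PR):

import torch
from transformers import AutoTokenizer, DebertaModel

# "microsoft/deberta-base" is an illustrative checkpoint; torchscript=True
# makes the model return tuples, which torch.jit.trace requires
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
model = DebertaModel.from_pretrained("microsoft/deberta-base", torchscript=True)
model.eval()

inputs = tokenizer("Hello world", return_tensors="pt")
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))
torch.jit.save(traced, "deberta_traced.pt")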

ArthurZucker avatar Mar 11 '23 10:03 ArthurZucker

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Hey @ArthurZucker, any updates on this? Any ETA for when it will be merged into main?

hriaz17 avatar May 22 '23 21:05 hriaz17

Hey! Just got back from holidays, this should be my main focus in the coming days!

ArthurZucker avatar May 23 '23 09:05 ArthurZucker

Sorry! Seems like I had to postpone this! If anyone wants to take over, feel free to do so; otherwise it will be my priority once https://github.com/huggingface/transformers/pull/23909 is merged!

ArthurZucker avatar Jun 19 '23 10:06 ArthurZucker

Regarding the z_steps in DebertaV2Model: I think this code is relevant for the enhanced mask decoder of the generator model:

# expand a 2D padding mask into the 4D self-attention mask the layers expect
if attention_mask.dim() <= 2:
    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
    att_mask = extended_attention_mask.byte()
    attention_mask = att_mask * att_mask.squeeze(-2).unsqueeze(-1)
elif attention_mask.dim() == 3:
    attention_mask = attention_mask.unsqueeze(1)
target_mask = target_ids > 0
# keys/values for the decoding steps come from the second-to-last layer
hidden_states = encoder_layers[-2]
if not self.position_biased_input:
    # the last encoder layer is reused twice, i.e. z_steps is hardcoded to 2
    layers = [encoder.layer[-1] for _ in range(2)]
    # z_states are the absolute position embeddings; adding the hidden states
    # yields the initial query states for the enhanced mask decoder
    z_states += hidden_states
    query_states = z_states
    query_mask = attention_mask
    outputs = []
    rel_embeddings = encoder.get_rel_embedding()

    for layer in layers:
        # TODO: pass relative pos ids
        output = layer(hidden_states, query_mask, return_att=False, query_states=query_states,
                       relative_pos=relative_pos, rel_embeddings=rel_embeddings)
        query_states = output
        outputs.append(query_states)
else:
    outputs = [encoder_layers[-1]]

As far as I can tell, z_steps is hardcoded to 2 here, although it should still be 0 for the discriminator. Adding z_steps to the config seems like a good idea.
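
A sketch of what that could look like (the z_steps field is the proposal here, not an existing DebertaV2Config attribute):

from transformers import DebertaV2Config, DebertaV2Model

# z_steps is the proposed field; HF configs store unknown kwargs as plain
# attributes, so no subclassing is needed to experiment with it
generator_config = DebertaV2Config(position_biased_input=False, z_steps=2)
discriminator_config = DebertaV2Config(z_steps=0)

model = DebertaV2Model(generator_config)
# inside the model, the value could be read defensively for old configs:
z_steps = getattr(model.config, "z_steps", 0)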

z_states represents the position embeddings, which are non-zero if position_biased_input is set to True; they are passed in from the embedding layer. So to implement this properly, I think we need to return the position embeddings here:

class DebertaV2Embeddings(nn.Module):
    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
        ...

        # new: also return the position embeddings for the z_steps decoder
        return embeddings, position_embeddings
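
This should be a small change: as far as I can tell, the current forward already materializes position_embeddings as a separate tensor before optionally adding it to the word embeddings, roughly like this (paraphrased, details may differ):

# position embeddings fall back to zeros when no embedding table exists,
# matching the "non-zero if position_biased_input" behavior described above
if self.position_embeddings is not None:
    position_embeddings = self.position_embeddings(position_ids.long())
else:
    position_embeddings = torch.zeros_like(inputs_embeds)

embeddings = inputs_embeds
if self.position_biased_input:
    embeddings = embeddings + position_embeddings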

and implement the z_steps like this:

class DebertaV2Model(DebertaV2PreTrainedModel):
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        ...

        # the embedding layer now returns the position embeddings as well
        embedding_output, position_embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        ...

        if self.z_steps > 0:
            # keys/values come from the second-to-last layer, as in the original code
            hidden_states = encoded_layers[-2]
            # reuse the last encoder layer z_steps times
            layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
            # inject the absolute position information into the query states
            position_embedding_output += hidden_states
            query_states = position_embedding_output
            query_mask = self.encoder.get_attention_mask(attention_mask)
            rel_embeddings = self.encoder.get_rel_embedding()
            rel_pos = self.encoder.get_rel_pos(embedding_output)
            for layer in layers:
                query_states = layer(
                    hidden_states,
                    query_mask,
                    output_attentions=False,
                    query_states=query_states,
                    relative_pos=rel_pos,
                    rel_embeddings=rel_embeddings,
                )
                encoded_layers = encoded_layers + (query_states,)
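
Assuming the patch above is applied, a quick smoke test could look like this (the tiny config values are illustrative):

import torch
from transformers import DebertaV2Config, DebertaV2Model

config = DebertaV2Config(hidden_size=128, num_hidden_layers=4,
                         num_attention_heads=4, intermediate_size=256)
config.z_steps = 2  # the proposed field; 0 would skip the extra decoding steps
model = DebertaV2Model(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones(1, 8, dtype=torch.long)
with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)
print(out.last_hidden_state.shape)  # torch.Size([1, 8, 128])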

zynaa avatar Jun 29 '23 00:06 zynaa

What is the status? The logs of the CI checks have expired.

Bachstelze avatar Feb 08 '24 13:02 Bachstelze

#27734 should help with some of the issues in the meantime.

ArthurZucker avatar Feb 13 '24 03:02 ArthurZucker