
Simple variants for `nn.Embedding` and `nn.LSTM`

vaseline555 opened this issue 4 years ago · 4 comments

Hi, I really enjoyed your work, and I believe I am the first to cite it! (https://arxiv.org/abs/2109.07628)

I made some simple variants of your method for language models; could you please take a look? I don't think they amount to a PR, so I opened an issue instead to ask the original authors for a quick check.

Thank you in advance.

Best, Adam


```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# LSTM layer
class SubspaceLSTM(nn.LSTM):
    def forward(self, x):
        # Call get_weight (defined in subclasses), which samples a weight from
        # the subspace, then run a fresh LSTM with the sampled weights.
        # self.alpha is assumed to be set externally before each forward pass.
        weight_dict = self.get_weight()
        mixed_lstm = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first,
        )
        for l in range(self.num_layers):
            # NOTE: re-wrapping the mixed weights in nn.Parameter detaches them
            # from the graph, so gradients do not reach the endpoint weights;
            # see the corrected implementation linked at the end of this issue.
            setattr(mixed_lstm, f'weight_hh_l{l}', nn.Parameter(weight_dict[f'weight_hh_l{l}_mixed']))
            setattr(mixed_lstm, f'weight_ih_l{l}', nn.Parameter(weight_dict[f'weight_ih_l{l}_mixed']))
            if self.bias:
                setattr(mixed_lstm, f'bias_hh_l{l}', nn.Parameter(weight_dict[f'bias_hh_l{l}_mixed']))
                setattr(mixed_lstm, f'bias_ih_l{l}', nn.Parameter(weight_dict[f'bias_ih_l{l}_mixed']))
        return mixed_lstm(x)

class TwoParamLSTM(SubspaceLSTM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register a second endpoint (suffix `_1`) for every LSTM weight,
        # initialized to zeros.
        for l in range(self.num_layers):
            setattr(self, f'weight_hh_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'weight_hh_l{l}'))))
            setattr(self, f'weight_ih_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'weight_ih_l{l}'))))
            if self.bias:
                setattr(self, f'bias_hh_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'bias_hh_l{l}'))))
                setattr(self, f'bias_ih_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'bias_ih_l{l}'))))
                
class LinesLSTM(TwoParamLSTM):
    def get_weight(self):
        # Linearly interpolate each weight between the two endpoints:
        # w_mixed = (1 - alpha) * w + alpha * w_1
        weight_dict = dict()
        names = ['weight_hh', 'weight_ih'] + (['bias_hh', 'bias_ih'] if self.bias else [])
        for l in range(self.num_layers):
            for name in names:
                w_0 = getattr(self, f'{name}_l{l}')
                w_1 = getattr(self, f'{name}_l{l}_1')
                weight_dict[f'{name}_l{l}_mixed'] = (1 - self.alpha) * w_0 + self.alpha * w_1
        return weight_dict

# Embedding layer
class SubspaceEmbedding(nn.Embedding):
    def forward(self, x):
        # Sample a weight from the subspace, then embed with it, forwarding
        # the usual nn.Embedding options (padding_idx, max_norm, etc.).
        w = self.get_weight()
        return F.embedding(
            x, w, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse,
        )

class TwoParamEmbedding(SubspaceEmbedding):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Second endpoint of the line, initialized to zeros.
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))


class LinesEmbedding(TwoParamEmbedding):
    def get_weight(self):
        # Linear interpolation between the two endpoints.
        return (1 - self.alpha) * self.weight + self.alpha * self.weight1
```
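
For completeness, a minimal usage sketch. The sizes here are arbitrary, and `alpha` is assumed to be set on each module externally before the forward pass:

```python
import torch

emb = LinesEmbedding(num_embeddings=1000, embedding_dim=64)
lstm = LinesLSTM(input_size=64, hidden_size=128, num_layers=2, batch_first=True)
for m in (emb, lstm):
    m.alpha = 0.3  # sample a point on the line between the two endpoints

tokens = torch.randint(0, 1000, (8, 20))  # (batch, seq_len)
output, (h_n, c_n) = lstm(emb(tokens))    # output: (8, 20, 128)
```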

vaseline555 · Jan 28, 2022

Hi Adam, this looks great! I don't have access to this repo anymore because I'm no longer on the internship, but let's keep this issue open so that other people can refer to this implementation!

In general we never looked at anything related to language or RNNs, so I'm very curious to see what you find.

Best, Mitchell

mitchellnw · Jan 28, 2022

And very nice paper -- thanks for sharing!

mitchellnw · Jan 28, 2022

Thank you for checking! Okay, I will keep this issue open. Have a nice weekend!

vaseline555 · Jan 29, 2022

@mitchellnw Hi, I forgot to update this issue with the corrected implementation! Here is the correct implementation of the LSTM layer:

https://github.com/vaseline555/SuPerFed/blob/a1d54616b31af03634aaa71011c87d28124e1d56/src/models/layers.py#L74
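
Roughly, the fix is to stop re-wrapping the mixed weights in `nn.Parameter` (which detaches them from the graph) and instead pass them straight to the functional LSTM backend. Below is a minimal sketch of that idea, assuming a unidirectional LSTM and using `torch._VF.lstm`, the private op behind `nn.LSTM.forward`; the actual code in the linked repo may differ in details:

```python
import torch

# Hedged sketch, not necessarily identical to the linked repo: call the
# functional backend with the mixed tensors so gradients reach both endpoints.
class FunctionalLinesLSTM(LinesLSTM):  # LinesLSTM as defined above
    def forward(self, x, hx=None):
        weight_dict = self.get_weight()
        flat_weights = []
        for l in range(self.num_layers):
            flat_weights += [weight_dict[f'weight_ih_l{l}_mixed'],
                             weight_dict[f'weight_hh_l{l}_mixed']]
            if self.bias:
                flat_weights += [weight_dict[f'bias_ih_l{l}_mixed'],
                                 weight_dict[f'bias_hh_l{l}_mixed']]
        if hx is None:  # default zero initial states (unidirectional only)
            batch = x.size(0) if self.batch_first else x.size(1)
            zeros = x.new_zeros(self.num_layers, batch, self.hidden_size)
            hx = (zeros, zeros)
        output, h_n, c_n = torch._VF.lstm(
            x, hx, flat_weights, self.bias, self.num_layers,
            self.dropout, self.training, self.bidirectional, self.batch_first)
        return output, (h_n, c_n)
```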

For those who are interested, please refer to my repo linked above. Thanks again to Mitchell for his great work!

vaseline555 · Aug 18, 2022