Time2Vec-PyTorch

Why not write it using `nn.Linear`?

Open • crazydogen opened this issue 2 years ago • 1 comment

Here is a simple version using `nn.Linear`:

```python
import torch
import torch.nn as nn


class Time2vec(nn.Module):
    def __init__(self, c_in, c_out, activation="cos"):
        super().__init__()
        # one affine map producing c_out - 1 features, another producing 1 feature
        self.wnbn = nn.Linear(c_in, c_out - 1, bias=True)
        self.w0b0 = nn.Linear(c_in, 1, bias=True)
        self.act = torch.cos if activation == "cos" else torch.sin

    def forward(self, x):
        # x: [N, L, C]; nn.Linear acts on the last (channel) dimension
        part0 = self.act(self.w0b0(x))  # [N, L, 1]
        part1 = self.act(self.wnbn(x))  # [N, L, c_out - 1]
        return torch.cat([part0, part1], -1)  # [N, L, c_out]


if __name__ == "__main__":
    test_x = torch.randn((1, 3, 3000))  # [N, C, L] -> batch, channel, length
    m = Time2vec(3, 10)
    out = m(test_x.permute(0, 2, 1))  # move channels last: [N, L, C]
    print(out.shape)  # torch.Size([1, 3000, 10])
```
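As for the question in the title: `nn.Linear(c_in, k)` stores a `weight` of shape `(k, c_in)` and a `bias` of shape `(k,)` and computes `x @ weight.T + bias`, which is the same affine map as keeping hand-made `nn.Parameter`s and calling `torch.matmul`. The check below is a minimal sketch of that equivalence; `ManualAffine` is a hypothetical hand-written layer used only for comparison, not the class this repository defines.

```python
import torch
import torch.nn as nn


class ManualAffine(nn.Module):
    """Hypothetical hand-written layer: explicit weight/bias applied with torch.matmul."""

    def __init__(self, c_in: int, c_out: int):
        super().__init__()
        self.w = nn.Parameter(torch.randn(c_in, c_out))  # stored as (in, out)
        self.b = nn.Parameter(torch.randn(c_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, self.w) + self.b


torch.manual_seed(0)
manual = ManualAffine(3, 9)
linear = nn.Linear(3, 9)

# Copy the manual parameters into nn.Linear; its weight is stored transposed as (out, in).
with torch.no_grad():
    linear.weight.copy_(manual.w.T)
    linear.bias.copy_(manual.b)

x = torch.randn(1, 3000, 3)  # [N, L, C], the layout after permute(0, 2, 1)
print(torch.allclose(manual(x), linear(x), atol=1e-5))  # True
```

The only bookkeeping difference is the transposed weight layout; the outputs agree up to floating-point rounding.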

crazydogen, Jun 08 '23 07:06

(Sorry for responding in someone else's repo)

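Seems reasonable, but note that the sin (or cos) should only be applied to `part1`, not `part0` (see here or equation (1) in the paper): the first Time2Vec component is the purely linear term ω₀τ + φ₀, and only the remaining components go through the periodic activation.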
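With that note applied, the snippet could look like the minimal sketch below, assuming the equation (1) form: t2v(τ)[0] = ω₀τ + φ₀ and t2v(τ)[i] = F(ωᵢτ + φᵢ) for 1 ≤ i ≤ k, with F = sin or cos. Class and attribute names are kept from the snippet above so the change is easy to spot; this is an illustration, not code from the repository.

```python
import torch
import torch.nn as nn


class Time2vec(nn.Module):
    def __init__(self, c_in, c_out, activation="cos"):
        super().__init__()
        self.wnbn = nn.Linear(c_in, c_out - 1, bias=True)  # periodic terms, i = 1..k
        self.w0b0 = nn.Linear(c_in, 1, bias=True)          # linear term, i = 0
        self.act = torch.cos if activation == "cos" else torch.sin

    def forward(self, x):
        # x: [N, L, C]
        part0 = self.w0b0(x)            # linear term: no activation
        part1 = self.act(self.wnbn(x))  # periodic terms: F(w_i x + b_i)
        return torch.cat([part0, part1], -1)  # [N, L, c_out]


if __name__ == "__main__":
    test_x = torch.randn((1, 3, 3000))  # [N, C, L]
    m = Time2vec(3, 10)
    out = m(test_x.permute(0, 2, 1))
    print(out.shape)  # torch.Size([1, 3000, 10])
```

Keeping `part0` linear lets the embedding represent non-periodic trends, while the periodic features stay bounded in [-1, 1].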

schuemie, Aug 29 '23 06:08