FastSpeech2

[suggestion] a simple LengthRegulator

Open · hccho2 opened this issue 4 years ago · 4 comments

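import torch
import torch.nn as nn
from utils.tools import pad  # FastSpeech2's padding helper (assumed import path)

class LengthRegulator(nn.Module):
    """ Length Regulator """

    def __init__(self):
        super(LengthRegulator, self).__init__()

    def LR(self, x, duration, max_len):
        # x: (N, T, D) = (batch size, text length, hidden size); duration: (N, T)
        N, T, D = x.shape

        # Flatten the batch so a single repeat_interleave call expands every token
        unbatch_x = x.reshape(-1, D)
        unbatch_duration = duration.flatten().long()
        output = torch.repeat_interleave(unbatch_x, unbatch_duration, dim=0)
        # Only correct if every sample's durations sum to the same total;
        # otherwise frames spill across samples or the reshape fails
        output = output.reshape(N, -1, D)

        mel_len = duration.sum(-1)
        if max_len is not None:
            output = pad(output, max_len)
        else:
            output = pad(output)

        return output, mel_len.long()

    def forward(self, x, duration, max_len):
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len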

Posted by hccho2, Apr 27 '21
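Below is a toy run of the core of LR (flatten, repeat_interleave, reshape). The pad step is left out so the snippet needs only PyTorch, and the tensor values are made up. When every sample's durations sum to the same total, each hidden vector is simply repeated duration times. When the totals differ, reshape(N, -1, D) still succeeds as long as the frame count divides evenly by N, but it splits frames at the wrong boundary. The helper name lr_core is hypothetical.

```python
import torch

def lr_core(x, duration):
    """Hypothetical helper: the flatten/repeat_interleave/reshape part of LR.

    x: (N, T, D) hidden states; duration: (N, T) integer frame counts.
    """
    N, T, D = x.shape
    flat = torch.repeat_interleave(x.reshape(-1, D), duration.flatten().long(), dim=0)
    return flat.reshape(N, -1, D), duration.sum(-1)

# N=2, T=2, D=2: sample 0 tokens [0,1],[2,3]; sample 1 tokens [4,5],[6,7]
x = torch.arange(8, dtype=torch.float32).reshape(2, 2, 2)

# Equal totals (3 and 3): each sample's frames stay together
equal_out, equal_len = lr_core(x, torch.tensor([[1, 2], [2, 1]]))

# Unequal totals (2 and 4): the 6 frames are split 3/3, so sample 0
# receives the first frame that belongs to sample 1
unequal_out, unequal_len = lr_core(x, torch.tensor([[1, 1], [2, 2]]))

print(equal_out[0].tolist())    # [[0.0, 1.0], [2.0, 3.0], [2.0, 3.0]]
print(unequal_out[0].tolist())  # [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
```

In the second case mel_len is still correct ([2, 4]); only the frame layout is wrong, so the error does not surface as an exception.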

@hccho2 Do you expect this suggestion to make the model JIT-traceable?

Posted by KinamSalad, Apr 30 '21

@KinamSalad I wasn't thinking about JIT decorators; I was only aiming for a simple implementation.

Posted by hccho2, Apr 30 '21

> @hccho2 Do you expect that the suggestion to make the model jit traceable?

See #35 for a JIT traceable length regulator

Posted by xDuck, May 6 '21

Replacing the for loop with repeat_interleave is a good idea. But can output.reshape(N, -1, D) give the right result when the samples in a batch expand to different lengths? It may be better to use repeat_interleave in place of the for loop inside the existing expand function, which can speed up both training and inference.

    def expand(self, batch, predicted):
        """
        Args:
            batch (torch.float32): shape: (TextLen, Hidden)
            predicted (torch.int64): shape: (TextLen,)
        """
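        # Original loop version, replaced by repeat_interleave below:
        # out = list()
        # for i, vec in enumerate(batch):
        #     expand_size = predicted[i].item()
        #     out.append(vec.expand(max(int(expand_size), 0), -1))
        # out = torch.cat(out, 0)

        # Repeat row i of batch predicted[i] times; negative durations become 0
        out = torch.repeat_interleave(batch, predicted.long().clamp(min=0), dim=0)
        return out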

Posted by hellolzc, Jul 12 '21
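A sanity check, assuming only PyTorch: the commented-out loop and the repeat_interleave line in expand should give identical results, including for zero and negative predicted durations, since clamp(min=0) mirrors max(int(expand_size), 0). Because expand works on one sample at a time, a batch can be handled by expanding each sample and zero-padding to the longest one, so samples with different total lengths stay separate. Here torch.nn.utils.rnn.pad_sequence stands in for the repo's own padding helper, and the names expand_loop, expand_interleave and regulate_batch are hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def expand_loop(batch, predicted):
    # Loop version (the commented-out code in expand)
    out = list()
    for i, vec in enumerate(batch):
        expand_size = predicted[i].item()
        out.append(vec.expand(max(int(expand_size), 0), -1))
    return torch.cat(out, 0)

def expand_interleave(batch, predicted):
    # repeat_interleave version; clamp(min=0) matches max(..., 0) above
    return torch.repeat_interleave(batch, predicted.long().clamp(min=0), dim=0)

def regulate_batch(x, duration):
    # x: (N, T, D), duration: (N, T); expand each sample, then zero-pad
    expanded = [expand_interleave(xi, di) for xi, di in zip(x, duration)]
    mel_len = torch.tensor([e.size(0) for e in expanded])
    return pad_sequence(expanded, batch_first=True), mel_len

batch = torch.randn(5, 4)                   # (TextLen, Hidden), random toy values
predicted = torch.tensor([2, 0, 3, -1, 1])  # includes zero and negative durations
a = expand_loop(batch, predicted)
b = expand_interleave(batch, predicted)
print(a.shape, torch.equal(a, b))  # torch.Size([6, 4]) True

x = torch.randn(2, 3, 4)
out, mel_len = regulate_batch(x, torch.tensor([[1, 1, 0], [2, 2, 1]]))
print(out.shape, mel_len.tolist())  # torch.Size([2, 5, 4]) [2, 5]
```

The speed-up hellolzc mentions should come from replacing one expand call per token plus a torch.cat with a single repeat_interleave call per sample; the Python loop in regulate_batch runs over the batch, not over tokens.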