[suggestion] a simple LengthRegulator
class LengthRegulator(nn.Module):
    """Length Regulator.

    Expands token-level hidden vectors to frame level by repeating each
    token's vector ``duration`` times, then pads the per-utterance results
    to a common length with the project helper ``pad``.
    """

    def __init__(self):
        super(LengthRegulator, self).__init__()

    def LR(self, x, duration, max_len):
        """Expand every utterance in the batch by its token durations.

        Args:
            x: hidden states, unpacked as ``(N, T, D)``.
            duration: per-token durations for each utterance; presumably
                shaped ``(N, T)`` — TODO confirm against callers. Values are
                cast to int64, so any fractional part is truncated.
            max_len: target padded length forwarded to ``pad``, or ``None``
                to let ``pad`` choose the length itself.

        Returns:
            ``(output, mel_len)``. ``output`` is whatever ``pad`` returns for
            the list of expanded utterances. ``mel_len`` is the int64 number
            of frames each utterance expanded to.
        """
        # Use the same integer repeats for expansion and for the lengths, so
        # mel_len always equals the real number of frames produced.
        repeats = duration.long()
        # Expand each utterance on its own. Utterances in a batch usually have
        # different total durations, so a single flattened repeat_interleave
        # followed by reshape(N, -1, D) either fails or mixes frames of
        # neighbouring utterances.
        expanded = [
            torch.repeat_interleave(tokens, reps, dim=0)
            for tokens, reps in zip(x, repeats)
        ]
        mel_len = repeats.sum(-1)
        # NOTE(review): assumes pad accepts a sequence of (len_i, D) tensors,
        # as it previously received a 3-D tensor iterated along dim 0 —
        # confirm against pad's definition.
        if max_len is not None:
            output = pad(expanded, max_len)
        else:
            output = pad(expanded)
        return output, mel_len

    def forward(self, x, duration, max_len):
        """Run the length regulator; see ``LR`` for argument details."""
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len
@hccho2 Do you expect this suggestion to make the model JIT-traceable?
@KinamSalad I wasn't thinking about JIT (e.g. decorating it with torch.jit). I only had a simple implementation in mind.
@hccho2 Do you expect this suggestion to make the model JIT-traceable?
See #35 for a JIT-traceable length regulator.
Replacing the `for` loop with `repeat_interleave` is a good idea. But can `output.reshape(N, -1, D)` give the right result when the samples in a batch have different total lengths? It may be better to use `repeat_interleave` in place of the `for` loop inside the `expand` function instead; that still speeds up both training and inference.
def expand(self, batch, predicted):
    """Expand token-level vectors to frame level by repeating each row.

    Row ``i`` of ``batch`` is repeated ``predicted[i]`` times. Durations are
    cast to int64 first (truncating any fractional part), and negative values
    are clamped to zero, so those tokens are dropped from the output.

    Args:
        batch (torch.float32): shape: (TextLen, Hidden)
        predicted (torch.int64): shape: (TextLen,)

    Returns:
        torch.Tensor: shape (sum of the clamped durations, Hidden).
    """
    repeats = predicted.long().clamp(min=0)
    expanded = torch.repeat_interleave(batch, repeats, dim=0)
    return expanded