autoregressive

Decouple input and output representations

Open • cheind opened this issue 4 years ago • 2 comments

Currently we assume that the input is also what we predict as output (just shifted by one step). However, with other research areas in mind, it might make sense to rework this more generally:

model
  input: BxIxT
  output: BxQxT

where I may equal Q, but does not have to. The training step would then look like this:

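```python
import torch.nn.functional as F


def training_step(self, batch, batch_idx):
    inputs = batch['x']  # BxIxT
    if 't' in batch:
        targets = batch['t']  # allows us to provide alternative targets
    elif inputs.shape[1] == self.num_outputs:  # I == Q (num_outputs: hypothetical attribute holding Q)
        # predict the next step of the input itself
        targets = inputs[..., 1:]
        inputs = inputs[..., :-1]
    else:
        raise ValueError('I != Q: alternative targets must be provided in batch["t"]')

    logits = self.forward(inputs)  # BxQxT
    loss = F.cross_entropy(logits, targets)
    return loss
```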
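To make the two branches concrete, here is a small sketch, assuming PyTorch, with a 1x1 `Conv1d` standing in for the real (causal) model. `targets_for` is a hypothetical helper that mirrors the branching in `training_step`. In the first case the inputs carry I = 3 feature channels while the model predicts Q = 5 classes, so targets must be given explicitly as class indices under `'t'`. In the second case I == Q and the one-hot inputs, shifted by one step, double as probability targets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, I, Q, T = 2, 3, 5, 8


def targets_for(batch: dict, num_outputs: int):
    """Hypothetical helper: return (inputs, targets) using the same rule as training_step."""
    inputs = batch['x']
    if 't' in batch:
        return inputs, batch['t']  # explicit targets, assumed aligned with the model output
    if inputs.shape[1] == num_outputs:  # I == Q: predict the next input step
        return inputs[..., :-1], inputs[..., 1:]
    raise ValueError('I != Q requires explicit targets in batch["t"]')


# case 1: I != Q, explicit class-index targets of shape BxT
model = nn.Conv1d(I, Q, kernel_size=1)  # stand-in model, BxIxT -> BxQxT
batch = {'x': torch.randn(B, I, T), 't': torch.randint(0, Q, (B, T))}
inputs, targets = targets_for(batch, Q)
logits = model(inputs)  # BxQxT
loss = F.cross_entropy(logits, targets)  # class dimension is dim 1

# case 2: I == Q, one-hot inputs are also the (shifted) targets
codes = torch.randint(0, Q, (B, T))
onehot = F.one_hot(codes, Q).permute(0, 2, 1).float()  # BxTxQ -> BxQxT
model_same = nn.Conv1d(Q, Q, kernel_size=1)
x_in, x_tgt = targets_for({'x': onehot}, Q)
loss_same = F.cross_entropy(model_same(x_in), x_tgt)  # targets used as class probabilities
```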

What's more, we need to think about input transforms. Currently, one-hot encoding is hardwired into the model. Instead, we could pass a differentiable `input_transform` to the model upon initialization. This would allow us to use differentiable embedding strategies.
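A minimal sketch of that idea, assuming PyTorch: the model receives its `input_transform` as a constructor argument, so a fixed one-hot encoding and a learned `nn.Embedding` become interchangeable. `OneHotTransform`, `EmbeddingTransform`, and `Model` are hypothetical names, and the 1x1 `Conv1d` only stands in for the real encoder.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OneHotTransform(nn.Module):
    """Hypothetical transform: integer codes BxT -> one-hot floats BxIxT (no parameters)."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.one_hot appends the class dimension last (BxTxI), so move it to dim 1
        return F.one_hot(x, self.num_classes).permute(0, 2, 1).float()


class EmbeddingTransform(nn.Module):
    """Hypothetical transform: integer codes BxT -> learned embeddings BxIxT (differentiable)."""

    def __init__(self, num_classes: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(num_classes, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.embed(x).permute(0, 2, 1)  # BxTxI -> BxIxT


class Model(nn.Module):
    def __init__(self, input_transform: nn.Module, in_channels: int, out_channels: int):
        super().__init__()
        self.input_transform = input_transform  # registered as a submodule, so it is trained too
        # stand-in for the real encoder: BxIxT -> BxQxT
        self.encode = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encode(self.input_transform(x))
```

Because the transform is assigned as an attribute of an `nn.Module`, its parameters show up in `model.parameters()` and are updated by the optimizer along with the encoder; swapping strategies only changes the constructor call, e.g. `Model(EmbeddingTransform(256, 32), 32, 256)` versus `Model(OneHotTransform(256), 256, 256)`.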

cheind, Dec 10 '21 18:12



dataset -> model -> loss

model:
    input: BxIxT
    input_transform: fn(BxKxT) -> BxIxT
    condition: BxCxT
    output: BxQxT
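One way to wire this interface together, sketched under assumptions: PyTorch, a learned 1x1 convolution as `input_transform` (BxKxT -> BxIxT), and the condition BxCxT projected to the hidden width and added to the input features. Adding a projected condition is only one common choice (concatenating along the channel dimension would also work), and `ConditionedModel` is a hypothetical name.

```python
import torch
import torch.nn as nn


class ConditionedModel(nn.Module):
    """Hypothetical model: raw BxKxT -> input_transform -> BxIxT, plus optional condition BxCxT -> BxQxT."""

    def __init__(self, input_transform: nn.Module, in_channels: int, cond_channels: int,
                 out_channels: int, hidden: int = 16):
        super().__init__()
        self.input_transform = input_transform
        self.in_proj = nn.Conv1d(in_channels, hidden, kernel_size=1)
        self.cond_proj = nn.Conv1d(cond_channels, hidden, kernel_size=1)
        self.out_proj = nn.Conv1d(hidden, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor, cond: torch.Tensor = None) -> torch.Tensor:
        h = self.in_proj(self.input_transform(x))  # BxHxT
        if cond is not None:
            h = h + self.cond_proj(cond)  # condition must share B and T with the input
        return self.out_proj(torch.relu(h))  # BxQxT


# usage: K=4 raw channels -> I=6 via a learned 1x1 conv, C=3 condition channels, Q=5 outputs
model = ConditionedModel(nn.Conv1d(4, 6, kernel_size=1), in_channels=6, cond_channels=3, out_channels=5)
out = model(torch.randn(2, 4, 10), torch.randn(2, 3, 10))  # shape 2x5x10
```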


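```python
import torch.nn.functional as F


def loss(batch, logits):
    # logits: BxQxT, one prediction per input step
    if 't' in batch:
        targets = batch['t'][..., 1:]  # BxQxT (class probabilities) or BxT (class indices)
    else:
        targets = batch['x'][..., 1:]  # 'x' either BxQxT or BxT
    logits = logits[..., :-1]  # output at step t predicts step t+1
    # compute cross-entropy on the logits; sampling discrete predictions first is not differentiable
    return F.cross_entropy(logits, targets)


def training_step(self, batch, batch_idx):
    inputs = batch['x']  # BxKxT, mapped to BxIxT by input_transform inside forward
    condition = batch.get('c')  # BxCxT, optional
    logits = self.forward(inputs, condition)  # BxQxT
    return loss(batch, logits)


def forward(self, inputs, cond=None):
    inputs = self.input_transform(inputs)  # BxKxT -> BxIxT
    outputs = self.encode(inputs, cond)  # BxQxT; encoder assumed to accept the condition
    return outputs
```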
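The `BxQxT or BxT` target comments rely on PyTorch's `F.cross_entropy` (version 1.10 and later) accepting either class indices or per-class probabilities as targets. A quick check, assuming PyTorch, that both target forms give the same loss when the probabilities are one-hot, using the same one-step shift as `loss` above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, Q, T = 2, 5, 8

logits = torch.randn(B, Q, T)  # model output, BxQxT
codes = torch.randint(0, Q, (B, T))  # 'x' as class indices, BxT
onehot = F.one_hot(codes, Q).permute(0, 2, 1).float()  # 'x' as one-hot, BxQxT

# drop the last prediction and the first target so output t is scored against input t+1
shifted_logits = logits[..., :-1]
loss_idx = F.cross_entropy(shifted_logits, codes[..., 1:])  # BxT index targets
loss_prob = F.cross_entropy(shifted_logits, onehot[..., 1:])  # BxQxT probability targets
```

With one-hot targets, the probability form reduces to the index form, so a dataset can supply whichever representation is more convenient; soft (non-one-hot) targets are only expressible in the BxQxT form.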

cheind, Dec 12 '21 14:12

Would that also work for a different interpretation of the model output, such as #24?

cheind, Dec 12 '21 14:12