Inquiring about training codes

Open xuanxu92 opened this issue 1 year ago • 3 comments

Thanks for the excellent work! Could you please release the training code when you have time?

xuanxu92, Oct 02 '24 14:10

Hey @xuanxu92, thanks for your interest in our project.

Our training code is based on PIA. We re-implemented the attention operation of the Motion Module and use the following attention mask during temporal attention.

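import torch


def make_tril_block_mask(video_length: int, patch_size: int, device):
    """
    Build an additive temporal attention mask: 0 where attention is allowed,
    -inf where it is blocked. The pattern below marks the allowed positions
    (1 = allowed) for video_length=4, patch_size=2:

    tensor([[[1., 1., 0., 0.],
             [1., 1., 0., 0.],
             [1., 1., 1., 0.],
             [1., 1., 1., 1.]]])
    """
    tmp_mask = torch.zeros(video_length, video_length)

    # warmup steps: the first `patch_size` frames attend to each other
    for idx in range(patch_size):
        tmp_mask[idx, :patch_size] = 1
    # tril blocks: each later frame attends to itself and all earlier frames
    for idx in range(patch_size, video_length):
        tmp_mask[idx, :idx + 1] = 1

    # convert the 0/1 pattern into an additive mask (0 = keep, -inf = block)
    tmp_mask = tmp_mask.type(torch.bool)
    mask = torch.zeros_like(tmp_mask, dtype=torch.float)
    mask.masked_fill_(tmp_mask.logical_not(), float('-inf'))
    return mask.to(device)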
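
For readers who want to see how such a mask behaves, here is a minimal sketch, not the project's actual training code. The name `make_tril_block_mask_vectorized` and the all-zero `scores` tensor are assumptions made purely for illustration; the function is a loop-free variant that should produce the same mask as the one above when `patch_size <= video_length`.

```python
import torch


def make_tril_block_mask_vectorized(video_length: int, patch_size: int, device="cpu"):
    """Hypothetical loop-free variant: causal mask with a bidirectional warm-up block."""
    # Lower-triangular pattern: frame i may attend to frames 0..i.
    allowed = torch.tril(torch.ones(video_length, video_length, dtype=torch.bool))
    # Warm-up block: the first `patch_size` frames attend to each other.
    allowed[:patch_size, :patch_size] = True
    # Additive mask: 0 where attention is allowed, -inf where it is blocked.
    mask = torch.zeros(video_length, video_length)
    mask.masked_fill_(~allowed, float("-inf"))
    return mask.to(device)


mask = make_tril_block_mask_vectorized(video_length=4, patch_size=2)

# Placeholder attention scores of shape (frames, frames); real scores would
# come from the query/key product inside the temporal attention layer.
scores = torch.zeros(4, 4)

# Adding the mask before softmax gives blocked positions zero weight.
weights = torch.softmax(scores + mask, dim=-1)
print(weights)
```

With uniform scores, the first two rows split their weight evenly over the two warm-up frames, and each later row spreads its weight over itself and all earlier frames.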

LeoXing1996, Oct 28 '24 03:10

Hey, thank you for the tips. I also want to try Live2Diff with uni-directional attention to see what happens. Are the training settings for the results shown in Live2Diff Figure 3(d) the same as those for the warm-up uni-directional attention training, e.g. 3000 steps, batch size = 1024, lr = 1e-4?

somuchtome, Oct 28 '24 14:10

Hey @xuanxu92, sorry for the late response. I checked the edit history of your comment. Regarding the earlier version: if you apply "full uni-directional" attention (i.e., the causal attention used in LLMs), it is understandable that the initial frames may get stuck, because the first few frames in Live2Diff are trained with "bi-directional" attention.

As for the current version of your comment, the answer is "yes."
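
To make the point about fully uni-directional attention concrete, the following plain-Python sketch compares the two attention patterns. The helper `allowed_pairs` and the values `video_length = 6` and `warmup = 3` are illustrative assumptions, not taken from the paper or the repository; `patch_size=1` is used here to represent plain causal attention.

```python
def allowed_pairs(video_length: int, patch_size: int):
    """allowed[i][j] is True if frame i may attend to frame j.

    Hypothetical helper: patch_size=1 reduces to a plain causal
    (lower-triangular) pattern; larger values add a warm-up block.
    """
    return [
        [j <= i or (i < patch_size and j < patch_size) for j in range(video_length)]
        for i in range(video_length)
    ]


video_length, warmup = 6, 3  # illustrative values only
causal = allowed_pairs(video_length, patch_size=1)
warm = allowed_pairs(video_length, patch_size=warmup)

# Frames whose visible context differs between the two patterns.
changed = [i for i in range(video_length) if causal[i] != warm[i]]
print(changed)  # [0, 1]
```

Only the earliest warm-up frames see a different context under the two patterns: with the warm-up block they can attend to the later warm-up frames, and under plain causal attention they cannot. Later frames are unaffected, which is consistent with the explanation that a mismatch would show up in the initial frames.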

LeoXing1996, Nov 09 '24 03:11