about the training code
Hello, thanks for the excellent work! May I ask when the training code will be released, and whether it will be released soon? Thank you!
@AZZMM Thank you for your interest in our work. We are currently expanding MIGC to MIGC++ with more comprehensive functionality. We expect to submit a report to arXiv in May 2024, and we plan to release the training code at that time.
Thank you for the quick response; looking forward to the updates.
Excuse me, I've implemented the training code myself and I am confused about the details of the inhibition loss. I learned from the paper that the loss is based on the attention map of the frozen cross-attention layer, which has shape (batch_size * head_dim, HW, text_len). May I ask how to handle the text_len dimension and the head_dim?
Summing along the text_len dimension gives 1, which I think is because of the softmax().
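For reference, here is a tiny toy check of that normalization (the shapes are made up; this is not the training code):

```python
import torch

# Toy shapes only: a frozen cross-attention map is (batch_size * heads, HW, text_len),
# and the softmax is taken over the text_len (token) dimension.
batch_heads, HW, text_len = 4, 64 * 64, 77
attn_map = torch.randn(batch_heads, HW, text_len).softmax(dim=-1)

# So every spatial position sums to 1 over the text tokens.
print(torch.allclose(attn_map.sum(dim=-1), torch.ones(batch_heads, HW)))  # True
```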
My implementation is like this:
attn_map = attn_map.reshape(batch_size // head_size, head_size, HW, text_len)[2:, ...]  # batch_size // head_size == instance_num + 2
attn_map = attn_map.permute(0, 2, 1, 3).reshape(-1, HW, text_len * head_size)
avg = torch.where(background_masks, attn_map, 0.0).sum(dim=1) / background_masks.sum(dim=1)
# avg shape: (instance_num, text_len * head_size)
ihbt_loss = (torch.abs(attn_map - avg[..., None, :]) * background_masks).sum() / background_masks.sum()
Is this right?
Looking forward to your reply, thank you!
pre_attn = fuser_info['pre_attn']  # (BPN, heads, HW, 77)
BPN, heads, HW, _ = pre_attn.shape
pre_attn = torch.sum(pre_attn[:, :, :, 1:], dim=-1)  # (BPN, heads, HW)
H = W = int(math.sqrt(HW))
pre_attn = pre_attn.view(bsz, -1, HW)  # (B, PN*heads, HW)
supplement_mask_inter = F.interpolate(supplement_mask, (H, W), mode=args.inter_mode)  # supplement_mask is the mask of the background
supplement_mask_inter = supplement_mask_inter[:, 0, ...].view(bsz, 1, HW)  # (B, 1, HW)
pre_attn_mean = (pre_attn * supplement_mask_inter).sum(dim=-1) / (supplement_mask_inter.sum(dim=-1) + 1e-6)  # (B, PN*heads)
aug_scale = 1
now_pre_attn_loss = (abs(pre_attn - pre_attn_mean[..., None].detach()) * supplement_mask_inter).sum(dim=-1) / (supplement_mask_inter.sum(dim=-1) + 1e-6)
now_pre_attn_loss = (now_pre_attn_loss * aug_scale).mean()
pre_attn_loss = pre_attn_loss + now_pre_attn_loss
pre_attn_loss_cnt = pre_attn_loss_cnt + 1
@AZZMM You can refer to this code to implement inhibition loss. If you still have questions, you can ask me here.
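For later readers, here is a rough self-contained sketch of how the snippet above could be wrapped into a helper. Only `fuser_info['pre_attn']`, `supplement_mask`, and `args.inter_mode` come from the snippet; the function name, the list-of-layers argument, and the final averaging are my own assumptions, not the actual MIGC training code:

```python
import math
import torch
import torch.nn.functional as F

def inhibition_loss(pre_attn_maps, supplement_mask, inter_mode='bilinear'):
    """Average the per-layer inhibition terms over the background region (illustrative wrapper).

    pre_attn_maps   : list of frozen cross-attn maps, each of shape (BPN, heads, HW, 77)
    supplement_mask : background mask of shape (B, 1, H0, W0)
    """
    bsz = supplement_mask.shape[0]
    total, cnt = 0.0, 0
    for pre_attn in pre_attn_maps:
        BPN, _, HW, _ = pre_attn.shape
        pre_attn = pre_attn[:, :, :, 1:].sum(dim=-1)              # (BPN, heads, HW)
        H = W = int(math.sqrt(HW))
        pre_attn = pre_attn.view(bsz, -1, HW)                     # (B, PN*heads, HW)

        mask = F.interpolate(supplement_mask, (H, W), mode=inter_mode)
        mask = mask[:, 0, ...].view(bsz, 1, HW)                   # (B, 1, HW)

        mean = (pre_attn * mask).sum(dim=-1) / (mask.sum(dim=-1) + 1e-6)   # (B, PN*heads)
        dev = ((pre_attn - mean[..., None].detach()).abs() * mask).sum(dim=-1) / (mask.sum(dim=-1) + 1e-6)
        total = total + dev.mean()
        cnt += 1
    return total / max(cnt, 1)
```

Note the `.detach()` on the background mean, as in the snippet: gradients only push each background position toward the (fixed) mean, i.e. they suppress deviations rather than shifting the mean itself.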
Thank you very much for your quick reply; the code is really helpful to me.
May I ask what BPN represents? Does it include the negative and global prompt attention maps?
Also, there are three attention maps at the 16*16 resolution in the frozen cross-attention layers. Is the final loss the sum of the three results?
@AZZMM During training, we don't need negative prompts. BPN means Batch * Phase_num, where Phase_num covers {global prompt, instance1_desc, instance2_desc, ..., instanceN_desc}. We use the first two 16*16 attention maps to calculate the inhibition loss.
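To make the BPN layout concrete, here is a small illustrative sketch; the tensor shapes, the layer list, and the assumption that the phase dimension is the outer one when folding B and PN together are mine, not the actual MIGC code:

```python
import torch

# BPN = Batch * Phase_num, where the phases are
# [global prompt, instance1_desc, ..., instanceN_desc]  (no negative prompt during training)
B, N_instances, heads, text_len = 2, 3, 8, 77
PN = 1 + N_instances

# Stand-ins for the first two of the three 16x16 frozen cross-attention maps
attn_maps_16 = [torch.rand(B * PN, heads, 16 * 16, text_len) for _ in range(2)]

for attn in attn_maps_16:
    BPN, _, HW, _ = attn.shape
    assert BPN == B * PN                      # batch and phases are folded together
    per_image = attn.view(B, PN, heads, HW, text_len)
    global_attn = per_image[:, :1]            # attention of the global prompt
    instance_attn = per_image[:, 1:]          # attention of each instance description
    print(global_attn.shape, instance_attn.shape)
```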
Thank you very much for your help! I'll try it.
Hello @limuloo! There is one detail I am not sure about: the cross-attention layers without MIGC. Do they use vanilla cross-attention or the naive fuser during training?
@AZZMM vanilla cross attention
I see, thank you!
Hi, may I ask when MIGC++ will be released?
@WUyinwei-hah We have already completed writing the MIGC++ paper, and we will submit it in the next few days. After that, we will move on to the open-sourcing of MIGC++.
Looking forward to it!
Hello, I've implemented the training code too. You could add me on WeChat so we can discuss it together. ID: 601663546