
about the training code

Open AZZMM opened this issue 1 year ago • 14 comments

Hello, thanks for the excellent work! May I ask when the training code will be released, and whether it will be released soon? Thank you!

AZZMM avatar May 05 '24 12:05 AZZMM

@AZZMM Thank you for your interest in our work. We are currently expanding MIGC to MIGC++ with more comprehensive functions. We expect to submit a report to arXiv in May 2024 and plan to release the training code at that time.

limuloo avatar May 06 '24 01:05 limuloo

Thank you for the quick response; looking forward to the updates.

AZZMM avatar May 06 '24 03:05 AZZMM

Excuse me, I've implemented the training code, but I am confused about the details of the inhibition loss. I learned from the paper that the loss is based on the attention map of the frozen cross-attention layer, with shape (batch_size*head_dim, HW, text_len). May I ask how to deal with the text_len dimension and the head_dim?

Summing along the text_len dimension gives 1, which I think is due to the softmax().

My implementation is like this:

```python
attn_map = attn_map.reshape(batch_size // head_size, head_size, HW, text_len)[2:, ...]  # batch_size: instance_num + 2
attn_map = attn_map.permute(0, 2, 1, 3).reshape(-1, HW, seq_len * head_size)
avg = torch.where(background_masks, attn_map, 0).sum(dim=1) / background_masks.sum(dim=1)  # avg shape: (instance_num, text_len*head_dim)
ihbt_loss = (torch.abs(attn_map - avg[..., None, :]) * background_masks).sum() / background_masks.sum()
```

Is this right?

Looking forward to your reply, thank you!

AZZMM avatar May 15 '24 12:05 AZZMM

```python
pre_attn = fuser_info['pre_attn']  # (BPN, heads, HW, 77)
BPN, heads, HW, _ = pre_attn.shape
pre_attn = torch.sum(pre_attn[:, :, :, 1:], dim=-1)  # (BPN, heads, HW)
H = W = int(math.sqrt(HW))
pre_attn = pre_attn.view(bsz, -1, HW)  # (B, PN*heads, HW)

# supplement_mask is the mask of the background (BG)
supplement_mask_inter = F.interpolate(supplement_mask, (H, W), mode=args.inter_mode)
supplement_mask_inter = supplement_mask_inter[:, 0, ...].view(bsz, 1, HW)  # (B, 1, HW)

pre_attn_mean = (pre_attn * supplement_mask_inter).sum(dim=-1) / \
    (supplement_mask_inter.sum(dim=-1) + 1e-6)  # (B, PN*heads)
aug_scale = 1
now_pre_attn_loss = (abs(pre_attn - pre_attn_mean[..., None].detach()) * supplement_mask_inter).sum(dim=-1) / \
    (supplement_mask_inter.sum(dim=-1) + 1e-6)
now_pre_attn_loss = (now_pre_attn_loss * aug_scale).mean()
pre_attn_loss = pre_attn_loss + now_pre_attn_loss
pre_attn_loss_cnt = pre_attn_loss_cnt + 1
```

@AZZMM You can refer to this code to implement the inhibition loss. If you still have questions, you can ask me here.

limuloo avatar May 15 '24 13:05 limuloo
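
For readers reconstructing this, here is a self-contained sketch of the same inhibition loss written as a standalone function. The function name, argument layout, and the `bilinear` interpolation mode are assumptions for illustration only; the repository's actual training code may differ.

```python
import math

import torch
import torch.nn.functional as F


def inhibition_loss(pre_attn: torch.Tensor,
                    supplement_mask: torch.Tensor,
                    bsz: int,
                    inter_mode: str = "bilinear") -> torch.Tensor:
    """Sketch of the inhibition loss described in the snippet above.

    pre_attn:        (B*PN, heads, HW, 77) attention map of a frozen
                     cross-attention layer, where PN is the phase number
                     (global prompt + N instance descriptions).
    supplement_mask: (B, 1, H0, W0) binary mask of the background region.
    """
    BPN, heads, HW, _ = pre_attn.shape
    # Sum over all text tokens except the first one -> (B*PN, heads, HW).
    attn = pre_attn[:, :, :, 1:].sum(dim=-1)
    H = W = int(math.sqrt(HW))
    # Merge the phase and head dimensions -> (B, PN*heads, HW).
    attn = attn.view(bsz, -1, HW)

    # Resize the background mask to the attention resolution -> (B, 1, HW).
    mask = F.interpolate(supplement_mask.float(), (H, W), mode=inter_mode)
    mask = mask[:, 0, :, :].reshape(bsz, 1, HW)

    # Mean attention received inside the background region -> (B, PN*heads).
    mean_bg = (attn * mask).sum(dim=-1) / (mask.sum(dim=-1) + 1e-6)
    # Penalise deviation from that mean over the background pixels.
    loss = (torch.abs(attn - mean_bg[..., None].detach()) * mask).sum(dim=-1)
    loss = loss / (mask.sum(dim=-1) + 1e-6)
    return loss.mean()
```

A typical call would pass the stored (B*PN, heads, HW, 77) attention map of a 16*16 cross-attention layer together with the background mask, and add the returned scalar to the training loss.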

Thank you very much for your quick reply; the code is really helpful to me. May I ask what BPN represents? Does it contain the negative and global prompt attention maps? Also, there are three 16*16 attention maps in the frozen cross-attention layers. Is the final loss the sum of the three results?

AZZMM avatar May 15 '24 16:05 AZZMM

@AZZMM During training, we don't need negative prompts. BPN means Batch * Phase_num, where Phase_num covers {global prompt, instance1_desc, instance2_desc, ..., instanceN_desc}. We use the first two 16*16 attn-maps to calculate the inhibition loss.

limuloo avatar May 15 '24 17:05 limuloo
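
To make the batching concrete, here is a toy illustration of how a (BPN, heads, HW, 77) attention map can be unpacked. The sizes below are made up; only the phase ordering {global prompt, instance_1, ..., instance_N} and the absence of a negative prompt come from the reply above.

```python
import torch

B, N = 2, 3                        # batch size and number of instances (arbitrary)
PN = 1 + N                         # phases: global prompt + N instance descriptions (no negative prompt in training)
heads, HW, T = 8, 16 * 16, 77      # per 16x16 cross-attention layer
pre_attn = torch.rand(B * PN, heads, HW, T)   # stand-in for a stored attention map

# Recover the phase dimension from the flattened BPN axis.
per_phase = pre_attn.view(B, PN, heads, HW, T)
global_attn = per_phase[:, 0]      # (B, heads, HW, 77)   -> global prompt
instance_attn = per_phase[:, 1:]   # (B, N, heads, HW, 77) -> instance descriptions
```

The inhibition loss is then computed per layer and accumulated over the first two 16*16 attention maps (cf. `pre_attn_loss` and `pre_attn_loss_cnt` in the snippet above).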

Thank you very much for your help! I'll try it.

AZZMM avatar May 16 '24 12:05 AZZMM

Hello! @limuloo There is one detail I'm not very sure about: for the cross-attention layers without MIGC, do you use vanilla cross attention or the naive fuser during training?

AZZMM avatar May 19 '24 11:05 AZZMM

@AZZMM vanilla cross attention

limuloo avatar May 19 '24 13:05 limuloo

I see, thank you!

AZZMM avatar May 20 '24 01:05 AZZMM


Hi, may I ask when MIGC++ will be released?

WUyinwei-hah avatar Jun 27 '24 07:06 WUyinwei-hah

@WUyinwei-hah We have already completed writing the MIGC++ paper and will submit it in the next few days. After that, we will move on to the open-sourcing work for MIGC++.

limuloo avatar Jun 28 '24 02:06 limuloo


Looking forward to it!

WUyinwei-hah avatar Jul 01 '24 02:07 WUyinwei-hah


Hello, I've implemented the training code too. You could add me on WeChat and we can discuss it together. ID: 601663546

2120160608 avatar Oct 14 '24 08:10 2120160608