Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
# We concatenated positive, visually-unpaired, and language-unpaired examples,
# ran a single forward pass, and now slice cap_logits / target_cap into the
# three corresponding parts.
# Vocabulary special tokens: PAD=0, SOS=1, EOS=2.

def _masked_caption_loss(logits, targets, crit):
    """Cross-entropy over non-special target tokens.

    NOTE(review): targets.ge(1.1) keeps only token ids >= 2, so it drops
    SOS (1) as well as PAD (0); the original comment claimed "padding" only.
    """
    mask = targets.ge(1.1)
    kept_targets = torch.masked_select(targets, mask)  # (n_kept,)
    kept_logits = logits[mask, :]                      # (n_kept, vocab)
    return crit(kept_logits, kept_targets.long().view(-1))

cap_crit = nn.CrossEntropyLoss()
target_cap = target_cap.long().view(-1)

# Kept for downstream code that may read the unmasked positive targets.
orig_pos_target = target_cap[:batch_size]

# Assumes rows are ordered [positive | visual-unpaired | lang-unpaired] and
# that cap_logits / target_cap are per-example at this point (slicing by
# batch_size would be wrong on a per-token flattened layout) -- TODO confirm
# against the caller.
pos_caption_loss = _masked_caption_loss(
    cap_logits[:batch_size, :], target_cap[:batch_size], cap_crit)
vp_caption_loss = _masked_caption_loss(
    cap_logits[batch_size:batch_size * 2, :],
    target_cap[batch_size:batch_size * 2], cap_crit)
lup_caption_loss = _masked_caption_loss(
    cap_logits[batch_size * 2:, :], target_cap[batch_size * 2:], cap_crit)

# Max-margin MMI ranking (MAttNet-style): drive the paired caption loss at
# least `margin` below each unpaired caption loss,
#     rank_loss = max(0, margin + paired - unpaired).
# BUG FIX: the original referenced undefined names visual_unpaired /
# lang_unpaired (the freshly computed vp_/lup_caption_loss were never used)
# and had the inverted form margin + unpaired - paired, which would push the
# paired loss ABOVE the unpaired ones.
margin = 1.0
visual_rank_weight = 1.0
lang_rank_weight = 0.1
visual_rank_loss = visual_rank_weight * torch.clamp(
    margin + pos_caption_loss - vp_caption_loss, 0)
lang_rank_loss = lang_rank_weight * torch.clamp(
    margin + pos_caption_loss - lup_caption_loss, 0)

mmi_loss = visual_rank_loss + lang_rank_loss
generation_loss = pos_caption_loss + mmi_loss
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement