Advertisement
Guest User

Untitled

a guest
Jan 22nd, 2020
92
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.71 KB | None | 0 0
  1. #we have concatenated the negative features and did a forward pass and can slice out the corresponding parts of of cap_logits and targets
  2. #PAD=0
  3. #SOS=1
  4. #EOS=2
  5. cap_crit = nn.CrossEntropyLoss()
  6. pos_cap_logits = cap_logits[:batch_size,:]
  7. target_cap = target_cap.long().view(-1)
  8. pos_target = target_cap[:batch_size]
  9. orig_pos_target = pos_target
  10. mask =pos_target.ge(1.1)# mask out padding
  11. pos_target=torch.masked_select(pos_target, mask) #(batch,)
  12. pos_cap_logits=pos_cap_logits[mask,:] #(batch, vocab)
  13. pos_caption_loss = cap_crit(pos_cap_logits, pos_target.long().view(-1))
  14.  
  15. visual_unpaired_cap_logits = cap_logits[batch_size: batch_size*2, :]
  16. vp_target = target_cap[batch_size: batch_size*2]
  17. mask =vp_target.ge(1.1)# mask out padding
  18. vp_target=torch.masked_select(vp_target, mask) #(batch,)
  19. visual_unpaired_cap_logits=visual_unpaired_cap_logits[mask,:] #(batch, vocab)
  20. vp_caption_loss = cap_crit(visual_unpaired_cap_logits, vp_target.long().view(-1))
  21.  
  22. lang_unpaired_cap_logits = cap_logits[batch_size*2:, :]
  23. lup_target = target_cap[batch_size*2:]
  24. mask =lup_target.ge(1.1)# mask out padding
  25. lup_target=torch.masked_select(lup_target, mask) #(batch,)
  26. lang_unpaired_cap_logits=lang_unpaired_cap_logits[mask,:] #(batch, vocab)
  27. lup_caption_loss = cap_crit(lang_unpaired_cap_logits, lup_target.long().view(-1))
  28. margin=1.0
  29. visual_rank_weight = 1.0
  30. lang_rank_weight = 0.1
  31. visual_rank_loss = visual_rank_weight * torch.clamp(margin + visual_unpaired - pos_caption_loss, 0)
  32. lang_rank_loss = lang_rank_weight * torch.clamp(margin + lang_unpaired - pos_caption_loss, 0)
  33. mmi_loss = (visual_rank_loss + lang_rank_loss)
  34.  
  35. generation_loss = pos_caption_loss + mmi_loss
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement