import itertools
import operator
from queue import PriorityQueue

import torch

# SOS_token, EOS_token, device and output_lang are assumed to be defined
# elsewhere in the surrounding script (e.g. a PyTorch seq2seq training setup).

class BeamSearchNode(object):
    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate: decoder hidden state at this step
        :param previousNode: parent BeamSearchNode (None for the root)
        :param wordId: id tensor of the word decoded at this step
        :param logProb: cumulative log-probability of the partial hypothesis
        :param length: length of the partial hypothesis, in tokens
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length

    def eval(self, alpha=1.0):
        reward = 0
        # Add here a function for shaping a reward

        # length-normalized log-probability plus an optional reward term
        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward

    def __lt__(self, other):
        # Tie-breaker for the PriorityQueue: without this, two entries with
        # equal scores fall through to comparing BeamSearchNode objects and
        # raise a TypeError.
        return self.leng < other.leng

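# Illustrative scoring example: a hypothesis of length 3 with cumulative
# log-probability -2.0 gets eval() = -2.0 / (3 - 1 + 1e-6) ≈ -1.0. The queue
# below stores -eval(), so the PriorityQueue (a min-heap) pops the most
# promising hypothesis first.
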
def beam_decode(target_tensor, decoder_hiddens, encoder_outputs, decoder):
    '''
    :param target_tensor: target index tensor of shape [B, T], where B is the batch size and T is the maximum length of the output sentence
    :param decoder_hiddens: input tensor of shape [1, B, H] for the start of the decoding
    :param encoder_outputs: if you are using an attention mechanism you can pass the encoder outputs, [T, B, H] where T is the maximum length of the input sentence
    :return: decoded_batch, a list with one list of decoded words per input sentence
    '''

    beam_width = 5
    topk = 1  # how many sentences to generate per input
    decoded_batch = []

    # decoding goes sentence by sentence
    for idx in range(target_tensor.size(0)):
        # slice out the hidden state and encoder outputs for this sentence
        decoder_hidden = decoder_hiddens[:, idx, :].unsqueeze(0)   # [1, 1, H]
        encoder_output = encoder_outputs[:, idx, :].unsqueeze(1)   # [T, 1, H]

        # Start with the start-of-sentence token
        # (torch.LongTensor does not accept a device keyword; torch.tensor does)
        decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)

        # number of finished hypotheses to collect for this sentence
        endnodes = []
        number_required = topk

        # starting node - hidden vector, previous node, word id, logp, length
        node = BeamSearchNode(decoder_hidden, None, decoder_input, 0, 1)
        nodes = PriorityQueue()

        # start the queue
        nodes.put((-node.eval(), node))
        qsize = 1

        # start beam search
        while True:
            # give up when decoding takes too long
            if qsize > 2000:
                break

            # fetch the best node (the queue stores -eval(), so the smallest
            # score is the most promising hypothesis)
            score, n = nodes.get()
            decoder_input = n.wordid
            decoder_hidden = n.h

            if n.wordid.item() == EOS_token and n.prevNode is not None:
                endnodes.append((score, n))
                # if we reached the maximum number of hypotheses required
                if len(endnodes) >= number_required:
                    break
                else:
                    continue

            # decode one step; decoder_output is expected to hold
            # log-probabilities of shape [1, V]
            decoder_output, decoder_hidden, _ = decoder(decoder_input, decoder_hidden, encoder_output)

            # expand the beam: take the top beam_width continuations
            log_prob, indexes = torch.topk(decoder_output, beam_width)
            nextnodes = []

            for new_k in range(beam_width):
                decoded_t = indexes[0][new_k].view(1, -1)
                log_p = log_prob[0][new_k].item()

                node = BeamSearchNode(decoder_hidden, n, decoded_t, n.logp + log_p, n.leng + 1)
                score = -node.eval()
                nextnodes.append((score, node))

            # put the new hypotheses into the queue
            for next_score, next_node in nextnodes:
                nodes.put((next_score, next_node))
            # one node was popped, beam_width nodes were pushed
            qsize += len(nextnodes) - 1

        # choose the n best paths and back-trace them; if no hypothesis
        # finished with EOS, fall back to the best unfinished ones
        if len(endnodes) == 0:
            endnodes = [nodes.get() for _ in range(topk)]

        utterances = []
        for score, n in sorted(endnodes, key=operator.itemgetter(0)):
            utterance = [n.wordid]
            # back-trace from the end node to the root
            while n.prevNode is not None:
                n = n.prevNode
                utterance.append(n.wordid)

            utterance = utterance[::-1]
            utterances.append(utterance)

        # flatten the topk utterances into one token sequence and map ids to words
        utterances = list(itertools.chain.from_iterable(utterances))
        decoded_words = [output_lang.index2word[word.item()] for word in utterances]
        decoded_batch.append(decoded_words)

    return decoded_batch
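
# --- Usage sketch (not part of the original paste) ---
# A minimal, self-contained example of calling beam_decode with a toy GRU
# decoder. Everything below (ToyDecoder, ToyLang, the vocabulary, SOS_token,
# EOS_token, device, output_lang) is an illustrative assumption, not the
# original training setup; the real decoder just has to return
# (log-probabilities over the vocabulary, hidden state, attention weights).
import torch.nn as nn
import torch.nn.functional as F

SOS_token, EOS_token = 0, 1
device = torch.device("cpu")

class ToyLang:
    # the vocabulary must hold at least beam_width (5) words for torch.topk
    index2word = {0: "SOS", 1: "EOS", 2: "the", 3: "cat", 4: "sat", 5: "down"}

output_lang = ToyLang()

class ToyDecoder(nn.Module):
    def __init__(self, vocab_size=6, hidden_size=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, word_input, hidden, encoder_outputs):
        embedded = self.embedding(word_input).view(1, 1, -1)      # [1, 1, H]
        output, hidden = self.gru(embedded, hidden)               # [1, 1, H]
        log_probs = F.log_softmax(self.out(output[0]), dim=1)     # [1, V]
        return log_probs, hidden, None                            # no attention here

torch.manual_seed(0)
decoder = ToyDecoder()
target_tensor = torch.zeros(1, 10, dtype=torch.long)              # [B=1, T=10]
decoder_hiddens = torch.zeros(1, 1, 8)                            # [1, B, H]
encoder_outputs = torch.zeros(10, 1, 8)                           # [T_in, B, H]
print(beam_decode(target_tensor, decoder_hiddens, encoder_outputs, decoder))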