import tensorflow as tf
from tensorflow.contrib import rnn

# get_init_embedding is a project helper (assumed here to live in a local
# "utils" module) that returns a [vocabulary_size, embedding_size] matrix of
# pretrained word vectors for the given vocabulary.
from utils import get_init_embedding


class Model(object):
    def __init__(self, reversed_dict, article_max_len, summary_max_len, args, forward_only=False):
        self.vocabulary_size = len(reversed_dict)
        self.embedding_size = args['embedding_size']
        self.num_hidden = args['num_hidden']
        self.num_layers = args['num_layers']
        self.learning_rate = args['learning_rate']
        self.beam_width = args['beam_width']
        # Dropout is only active during training; at inference keep_prob is fixed at 1.0.
        if not forward_only:
            self.keep_prob = args['keep_prob']
        else:
            self.keep_prob = 1.0
        self.cell = tf.nn.rnn_cell.LSTMCell
        # Shared projection from decoder hidden states to vocabulary logits.
        with tf.variable_scope("decoder/projection"):
            self.projection_layer = tf.layers.Dense(self.vocabulary_size, use_bias=False)

        # Placeholders: article token ids, summary token ids, and their true lengths.
        self.batch_size = tf.placeholder(tf.int32, (), name="batch_size")
        self.X = tf.placeholder(tf.int32, [None, article_max_len])
        self.X_len = tf.placeholder(tf.int32, [None])
        self.decoder_input = tf.placeholder(tf.int32, [None, summary_max_len])
        self.decoder_len = tf.placeholder(tf.int32, [None])
        self.decoder_target = tf.placeholder(tf.int32, [None, summary_max_len])
        self.global_step = tf.Variable(0, trainable=False)

        with tf.name_scope("embedding"):
            # Initialize from pretrained vectors when the w2v option is set at
            # training time; otherwise use random uniform initialization.
            if not forward_only and args['w2v']:
                init_embeddings = tf.constant(get_init_embedding(reversed_dict, self.embedding_size), dtype=tf.float32)
            else:
                init_embeddings = tf.random_uniform([self.vocabulary_size, self.embedding_size], -1.0, 1.0)
            self.embeddings = tf.get_variable("embeddings", initializer=init_embeddings)
            # Transpose to time-major [time, batch, embedding_size] for the RNNs below.
            self.encoder_emb_inp = tf.transpose(tf.nn.embedding_lookup(self.embeddings, self.X), perm=[1, 0, 2])
            self.decoder_emb_inp = tf.transpose(tf.nn.embedding_lookup(self.embeddings, self.decoder_input), perm=[1, 0, 2])

        with tf.name_scope("encoder"):
            # Stacked bidirectional LSTM encoder with dropout on each layer's output.
            fw_cells = [self.cell(self.num_hidden) for _ in range(self.num_layers)]
            bw_cells = [self.cell(self.num_hidden) for _ in range(self.num_layers)]
            # Pass keep_prob explicitly; DropoutWrapper defaults to 1.0 (no dropout).
            fw_cells = [rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob) for cell in fw_cells]
            bw_cells = [rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob) for cell in bw_cells]

            encoder_outputs, encoder_state_fw, encoder_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
                fw_cells, bw_cells, self.encoder_emb_inp,
                sequence_length=self.X_len, time_major=True, dtype=tf.float32)
            self.encoder_output = tf.concat(encoder_outputs, 2)
            # Concatenate the forward and backward final states of the first layer
            # into a single LSTM state of size 2 * num_hidden for the decoder.
            encoder_state_c = tf.concat((encoder_state_fw[0].c, encoder_state_bw[0].c), 1)
            encoder_state_h = tf.concat((encoder_state_fw[0].h, encoder_state_bw[0].h), 1)
            self.encoder_state = rnn.LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)

        with tf.name_scope("decoder"), tf.variable_scope("decoder") as decoder_scope:
            decoder_cell = self.cell(self.num_hidden * 2)

            if not forward_only:
                # Training: Bahdanau attention over the encoder outputs, with a
                # TrainingHelper feeding the gold summary tokens (teacher forcing).
                attention_states = tf.transpose(self.encoder_output, [1, 0, 2])
                attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                    self.num_hidden * 2, attention_states, memory_sequence_length=self.X_len, normalize=True)
                decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism,
                                                                   attention_layer_size=self.num_hidden * 2)
                initial_state = decoder_cell.zero_state(dtype=tf.float32, batch_size=self.batch_size)
                initial_state = initial_state.clone(cell_state=self.encoder_state)
                helper = tf.contrib.seq2seq.TrainingHelper(self.decoder_emb_inp, self.decoder_len, time_major=True)
                decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, initial_state)
                outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, output_time_major=True, scope=decoder_scope)
                self.decoder_output = outputs.rnn_output
                self.logits = tf.transpose(
                    self.projection_layer(self.decoder_output), perm=[1, 0, 2])
                # Zero-pad the logits along the time axis up to summary_max_len so
                # their shape matches decoder_target in the loss below.
                self.logits_reshape = tf.concat(
                    [self.logits, tf.zeros([self.batch_size, summary_max_len - tf.shape(self.logits)[1], self.vocabulary_size])], axis=1)
            else:
                # Inference: tile the encoder outputs, final state, and sequence
                # lengths beam_width times, then decode with beam search.
                tiled_encoder_output = tf.contrib.seq2seq.tile_batch(
                    tf.transpose(self.encoder_output, perm=[1, 0, 2]), multiplier=self.beam_width)
                tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(self.encoder_state, multiplier=self.beam_width)
                tiled_seq_len = tf.contrib.seq2seq.tile_batch(self.X_len, multiplier=self.beam_width)
                attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                    self.num_hidden * 2, tiled_encoder_output, memory_sequence_length=tiled_seq_len, normalize=True)
                decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism,
                                                                   attention_layer_size=self.num_hidden * 2)
                initial_state = decoder_cell.zero_state(dtype=tf.float32, batch_size=self.batch_size * self.beam_width)
                initial_state = initial_state.clone(cell_state=tiled_encoder_final_state)
                # Token ids 2 and 3 are assumed to be the <s> and </s> entries of
                # the vocabulary this model is built with.
                decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=self.embeddings,
                    start_tokens=tf.fill([self.batch_size], tf.constant(2)),
                    end_token=tf.constant(3),
                    initial_state=initial_state,
                    beam_width=self.beam_width,
                    output_layer=self.projection_layer
                )
                outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder, output_time_major=True, maximum_iterations=summary_max_len, scope=decoder_scope)
                # Shape [batch, beam_width, time]: one candidate summary per beam.
                self.prediction = tf.transpose(outputs.predicted_ids, perm=[1, 2, 0])

        with tf.name_scope("loss"):
            if not forward_only:
                crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=self.logits_reshape, labels=self.decoder_target)
                # Mask padding positions so they do not contribute to the loss.
                weights = tf.sequence_mask(self.decoder_len, summary_max_len, dtype=tf.float32)
                self.loss = tf.reduce_sum(crossent * weights / tf.to_float(self.batch_size))

                # Adam with global-norm gradient clipping at 5.0.
                params = tf.trainable_variables()
                gradients = tf.gradients(self.loss, params)
                clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
                optimizer = tf.train.AdamOptimizer(self.learning_rate)
                self.update = optimizer.apply_gradients(zip(clipped_gradients, params), global_step=self.global_step)
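

# Minimal usage sketch (illustrative, not part of the original model file):
# builds the training graph with dummy data and runs one optimizer step.
# The args values, vocabulary, and batch contents below are all assumptions.
if __name__ == "__main__":
    import numpy as np

    article_max_len, summary_max_len = 50, 15
    reversed_dict = {i: "w%d" % i for i in range(1000)}  # dummy id -> word map
    args = {"embedding_size": 300, "num_hidden": 150, "num_layers": 1,
            "learning_rate": 1e-3, "beam_width": 10, "keep_prob": 0.8,
            "w2v": False}  # w2v=False so no pretrained vectors are needed

    model = Model(reversed_dict, article_max_len, summary_max_len, args, forward_only=False)
    batch = 4
    feed = {
        model.batch_size: batch,
        model.X: np.zeros([batch, article_max_len], dtype=np.int32),
        model.X_len: np.full([batch], article_max_len, dtype=np.int32),
        model.decoder_input: np.zeros([batch, summary_max_len], dtype=np.int32),
        model.decoder_len: np.full([batch], summary_max_len, dtype=np.int32),
        model.decoder_target: np.zeros([batch, summary_max_len], dtype=np.int32),
    }
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _, step, loss = sess.run([model.update, model.global_step, model.loss], feed_dict=feed)
        print("step %d, loss %.4f" % (step, loss))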