def _forward(self, x, y, model_params, init_states, is_training=False):
  """Computes the logits and losses.

  Args:
    x: [batch_size, num_steps], input batch.
    y: [batch_size, num_steps], target batch.
    model_params: a `dict` of model parameters to use.
    init_states: a `dict` of initial RNN states.
    is_training: if `True`, will apply regularizations.

  Returns:
    reg_loss: scalar, cross-entropy loss plus regularization terms.
    loss: scalar, cross-entropy loss.
  """
  w_emb = model_params['w_emb']
  w_prev = model_params['w_prev']
  w_skip = model_params['w_skip']
  w_soft = model_params['w_soft']
  prev_s = init_states['s']

  emb = tf.nn.embedding_lookup(w_emb, x)
  batch_size = self.params.batch_size
  hidden_size = self.params.hidden_size
  sample_arc = self.sample_arc
  if is_training:
    # Embedding dropout; noise_shape [batch_size, 1, hidden_size] shares the
    # mask across time steps.
    emb = tf.layers.dropout(
        emb, self.params.drop_i, [batch_size, 1, hidden_size], training=True)

    # Per-example dropout masks for the recurrent cell's inputs and layers.
    input_mask = _gen_mask([batch_size, hidden_size], self.params.drop_x)
    layer_mask = _gen_mask([batch_size, hidden_size], self.params.drop_l)
  else:
    input_mask = None
    layer_mask = None

  out_s, all_s, var_s = _rnn_fn(sample_arc, emb, prev_s, w_prev, w_skip,
                                input_mask, layer_mask, params=self.params)
  top_s = all_s
  if is_training:
    top_s = tf.layers.dropout(
        top_s, self.params.drop_o,
        [self.params.batch_size, 1, self.params.hidden_size], training=True)
  # Carry the final RNN state over to the next batch.
  carry_on = [tf.assign(prev_s, out_s)]
  logits = tf.einsum('bnh,vh->bnv', top_s, w_soft)
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,
                                                        logits=logits)
  loss = tf.reduce_mean(loss)
  reg_loss = loss  # `loss + regularization_terms` is for training only
  if is_training:
    # L2 weight reg (sum of squared weights; the original paste applied
    # tf.nn.l2_loss to w ** 2, which squares the weights twice)
    self.l2_reg_loss = tf.add_n([tf.reduce_sum(w ** 2) for w in var_s])
    reg_loss += self.params.weight_decay * self.l2_reg_loss

    # activation L2 reg
    reg_loss += self.params.alpha * tf.reduce_mean(all_s ** 2)

    # activation slowness reg
    reg_loss += self.params.beta * tf.reduce_mean(
        (all_s[:, 1:, :] - all_s[:, :-1, :]) ** 2)
  with tf.control_dependencies(carry_on):
    loss = tf.identity(loss)
    if is_training:
      reg_loss = tf.identity(reg_loss)

  return reg_loss, loss
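
For context, here is a minimal, hypothetical sketch of how this method could be wired into a TF 1.x training step. Only `_forward`'s signature and its `(reg_loss, loss)` return value come from the code above; the helper name `_build_train_ops`, the optimizer choice, learning rate, and clipping threshold are placeholder assumptions. Gradients are taken against `reg_loss` (cross-entropy plus the regularization terms), while the plain cross-entropy `loss` is what you would report, e.g. as perplexity.

# Hypothetical usage sketch (TF 1.x); names other than `_forward` are assumptions.
import tensorflow as tf

def _build_train_ops(model, x, y, model_params, init_states):
  """Builds a train op on the regularized loss; also returns the plain loss."""
  reg_loss, loss = model._forward(
      x, y, model_params, init_states, is_training=True)

  # Backprop through reg_loss; `loss` (un-regularized cross-entropy) is only
  # for monitoring, e.g. perplexity = exp(loss).
  tf_vars = tf.trainable_variables()
  grads = tf.gradients(reg_loss, tf_vars)
  grads, grad_norm = tf.clip_by_global_norm(grads, clip_norm=0.25)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
  train_op = optimizer.apply_gradients(zip(grads, tf_vars))
  return train_op, loss, grad_norm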