  def _forward(self, x, y, model_params, init_states, is_training=False):
    """Computes the logits and the cross-entropy loss.

    Args:
      x: [batch_size, num_steps], input batch.
      y: [batch_size, num_steps], output batch.
      model_params: a `dict` of model parameters to use.
      init_states: a `dict` of initial recurrent states.
      is_training: if `True`, will apply regularizations.

    Returns:
      reg_loss: scalar, cross-entropy loss plus regularization terms
        (identical in value to `loss` when `is_training` is `False`).
      loss: scalar, cross-entropy loss.
    """
    # Model parameters and the recurrent state carried over from the
    # previous batch.
    w_emb = model_params['w_emb']
    w_prev = model_params['w_prev']
    w_skip = model_params['w_skip']
    w_soft = model_params['w_soft']
    prev_s = init_states['s']

    emb = tf.nn.embedding_lookup(w_emb, x)
    batch_size = self.params.batch_size
    hidden_size = self.params.hidden_size
    sample_arc = self.sample_arc
    if is_training:
      # Variational dropout on the embeddings: the noise shape
      # [batch_size, 1, hidden_size] shares one mask across all time steps.
      emb = tf.layers.dropout(
          emb, self.params.drop_i, [batch_size, 1, hidden_size], training=True)

      # Dropout masks for the RNN inputs and hidden layers, likewise shared
      # across time steps.
      input_mask = _gen_mask([batch_size, hidden_size], self.params.drop_x)
      layer_mask = _gen_mask([batch_size, hidden_size], self.params.drop_l)
    else:
      input_mask = None
      layer_mask = None

    # Unroll the RNN with the sampled architecture over the sequence.
    out_s, all_s, var_s = _rnn_fn(sample_arc, emb, prev_s, w_prev, w_skip,
                                  input_mask, layer_mask, params=self.params)

    top_s = all_s
    if is_training:
      # Output dropout, again with one mask shared across time steps.
      top_s = tf.layers.dropout(
          top_s, self.params.drop_o,
          [self.params.batch_size, 1, self.params.hidden_size], training=True)

    # Carry the final recurrent state over to the next batch.
    carry_on = [tf.assign(prev_s, out_s)]
    # logits: [batch_size, num_steps, vocab_size]
    logits = tf.einsum('bnh,vh->bnv', top_s, w_soft)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,
                                                          logits=logits)
    loss = tf.reduce_mean(loss)

    reg_loss = loss  # `loss + regularization_terms` is for training only
    if is_training:
      # L2 weight regularization.
      self.l2_reg_loss = tf.add_n([tf.reduce_sum(w ** 2) for w in var_s])
      reg_loss += self.params.weight_decay * self.l2_reg_loss

      # Activation L2 regularization.
      reg_loss += self.params.alpha * tf.reduce_mean(all_s ** 2)

      # Activation slowness (temporal smoothness) regularization.
      reg_loss += self.params.beta * tf.reduce_mean(
          (all_s[:, 1:, :] - all_s[:, :-1, :]) ** 2)

    # Ensure the state update runs before the losses are returned.
    with tf.control_dependencies(carry_on):
      loss = tf.identity(loss)
      if is_training:
        reg_loss = tf.identity(reg_loss)

    return reg_loss, loss
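
The helpers `_gen_mask` and `_rnn_fn` are defined elsewhere and are not part of this paste. For context only, here is a minimal sketch of what `_gen_mask` plausibly does, assuming it produces a standard inverted-dropout keep mask under the TF 1.x API; the actual implementation may differ.

import tensorflow as tf  # TF 1.x API assumed

def _gen_mask(shape, drop_prob):
  """Sketch of an assumed helper: Bernoulli keep-mask with inverted-dropout scaling."""
  keep_prob = 1.0 - drop_prob
  # Each entry survives with probability keep_prob; survivors are scaled by
  # 1 / keep_prob so the expected activation magnitude is unchanged.
  mask = tf.random_uniform(shape, dtype=tf.float32)
  mask = tf.floor(mask + keep_prob) / keep_prob
  return mask

Multiplying an input or hidden state by such a mask once per batch, rather than resampling it at every time step, is what makes the dropout in `_forward` "variational": the same units are dropped at every step of the unrolled sequence.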