Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def _forward(self, x, y, model_params, init_states, is_training=False):
  """Builds the forward graph and computes the batch losses.

  Args:
    x: [batch_size, num_steps], input token ids.
    y: [batch_size, num_steps], target token ids.
    model_params: `dict` with keys 'w_emb', 'w_prev', 'w_skip', 'w_soft'.
    init_states: `dict` with key 's', the carried-over recurrent state
      variable (assigned in place each step via `tf.assign`).
    is_training: if `True`, applies dropout and regularization terms.

  Returns:
    reg_loss: scalar, cross-entropy loss plus regularization terms
      (identical to `loss` when `is_training` is `False`).
    loss: scalar, plain cross-entropy loss.
  """
  w_emb = model_params['w_emb']
  w_prev = model_params['w_prev']
  w_skip = model_params['w_skip']
  w_soft = model_params['w_soft']
  prev_s = init_states['s']
  emb = tf.nn.embedding_lookup(w_emb, x)
  batch_size = self.params.batch_size
  hidden_size = self.params.hidden_size
  sample_arc = self.sample_arc

  if is_training:
    # Variational-style input dropout: noise_shape [batch, 1, hidden]
    # shares one mask across all time steps of an example.
    emb = tf.layers.dropout(
        emb, self.params.drop_i, [batch_size, 1, hidden_size], training=True)

    input_mask = _gen_mask([batch_size, hidden_size], self.params.drop_x)
    layer_mask = _gen_mask([batch_size, hidden_size], self.params.drop_l)
  else:
    input_mask = None
    layer_mask = None

  out_s, all_s, var_s = _rnn_fn(sample_arc, emb, prev_s, w_prev, w_skip,
                                input_mask, layer_mask, params=self.params)

  top_s = all_s
  if is_training:
    # Output dropout, again with one mask shared across time steps.
    top_s = tf.layers.dropout(
        top_s, self.params.drop_o,
        [self.params.batch_size, 1, self.params.hidden_size], training=True)

  # Carry the final RNN state into the next batch (stateful unrolling).
  carry_on = [tf.assign(prev_s, out_s)]
  logits = tf.einsum('bnh,vh->bnv', top_s, w_soft)
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,
                                                        logits=logits)
  loss = tf.reduce_mean(loss)

  reg_loss = loss  # `loss + regularization_terms` is for training only
  if is_training:
    # L2 weight regularization.
    # BUG FIX: was `tf.nn.l2_loss(w ** 2)`, which equals sum(w**4)/2 — an
    # L4 penalty, contradicting the stated L2 intent and `weight_decay`
    # semantics. `tf.nn.l2_loss(w)` is the standard sum(w**2)/2 form.
    self.l2_reg_loss = tf.add_n([tf.nn.l2_loss(w) for w in var_s])
    reg_loss += self.params.weight_decay * self.l2_reg_loss

    # Activation L2 regularization (AR).
    reg_loss += self.params.alpha * tf.reduce_mean(all_s ** 2)

    # Activation slowness regularization (TAR): penalizes large changes
    # between consecutive time steps.
    reg_loss += self.params.beta * tf.reduce_mean(
        (all_s[:, 1:, :] - all_s[:, :-1, :]) ** 2)

  # Force the state-assign op to run whenever either loss is evaluated.
  with tf.control_dependencies(carry_on):
    loss = tf.identity(loss)
    if is_training:
      reg_loss = tf.identity(reg_loss)

  return reg_loss, loss
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement