Not a member of Pastebin yet?
Sign up;
it unlocks many cool features!
def generator(num_categories=None, temperature=None, batch_size=None):
    """Build a trainable categorical generator sampled via Gumbel-softmax.

    Creates a trainable logits vector under the ``generator`` variable scope
    and draws relaxed one-hot samples from a ``RelaxedOneHotCategorical``
    (Concrete / Gumbel-softmax) distribution parameterized by those logits,
    so gradients can flow from the samples back to the logits.

    Every argument is optional. When omitted, it falls back to the
    module-level name the original implementation read directly (defined
    elsewhere in the file), looked up at call time.

    Args:
        num_categories: Number of categories K (length of the logits
            vector). Defaults to ``len(number_to_prob)``.
        temperature: Relaxation temperature passed to
            ``RelaxedOneHotCategorical``. Defaults to ``TEMPERATURE``.
        batch_size: Number of samples to draw. Defaults to ``BATCH_SIZE``.

    Returns:
        A ``(generated, probs)`` tuple:
          generated: relaxed one-hot samples, shape
            ``[batch_size, num_categories]``.
          probs: ``softmax(logits)``, shape ``[num_categories]``; the
            categorical probabilities currently encoded by the logits.

    Note:
        Uses TF 1.x graph-mode APIs (``tf.variable_scope``,
        ``tf.get_variable``, ``tf.contrib``). Calling this twice in the same
        graph raises ``ValueError`` (variable ``generator/logits`` already
        exists), unless an enclosing variable scope enables reuse.
    """
    # Resolve defaults at call time so behavior matches the original, which
    # read these module-level names inside the function body.
    if num_categories is None:
        num_categories = len(number_to_prob)
    if temperature is None:
        temperature = TEMPERATURE
    if batch_size is None:
        batch_size = BATCH_SIZE

    with tf.variable_scope('generator'):
        # All-equal (ones) logits -> softmax is uniform at initialization.
        logits = tf.get_variable('logits', initializer=tf.ones([num_categories]))
        gumbel_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
            temperature, logits=logits)
        probs = tf.nn.softmax(logits)
        generated = gumbel_dist.sample(batch_size)
    return generated, probs
Add Comment
Please sign in to add a comment.