Advertisement
Guest User

Untitled

a guest
Jul 20th, 2019
79
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.07 KB | None | 0 0
  1. def train_autoencoder(X_train, n_neurons, n_epochs, batch_size,
  2. learning_rate = 0.01, l2_reg = 0.0005, seed=42,
  3. hidden_activation=tf.nn.elu,
  4. output_activation=tf.nn.elu):
  5. graph = tf.Graph()
  6. with graph.as_default():
  7. tf.set_random_seed(seed)
  8.  
  9. n_inputs = X_train.shape[1]
  10.  
  11. X = tf.placeholder(tf.float32, shape=[None, n_inputs])
  12.  
  13. my_dense_layer = partial(
  14. tf.layers.dense,
  15. kernel_initializer=tf.contrib.layers.variance_scaling_initializer(),
  16. kernel_regularizer=tf.contrib.layers.l2_regularizer(l2_reg))
  17.  
  18. hidden = my_dense_layer(X, n_neurons, activation=hidden_activation, name="hidden")
  19. outputs = my_dense_layer(hidden, n_inputs, activation=output_activation, name="outputs")
  20.  
  21. reconstruction_loss = tf.reduce_mean(tf.square(outputs - X))
  22.  
  23. reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  24. loss = tf.add_n([reconstruction_loss] + reg_losses)
  25.  
  26. optimizer = tf.train.AdamOptimizer(learning_rate)
  27. training_op = optimizer.minimize(loss)
  28.  
  29. init = tf.global_variables_initializer()
  30.  
  31. with tf.Session(graph=graph) as sess:
  32. init.run()
  33. for epoch in range(n_epochs):
  34. n_batches = len(X_train) // batch_size
  35. for iteration in range(n_batches):
  36. print("\r{}%".format(100 * iteration // n_batches), end="")
  37. sys.stdout.flush()
  38. indices = rnd.permutation(len(X_train))[:batch_size]
  39. X_batch = X_train[indices]
  40. sess.run(training_op, feed_dict={X: X_batch})
  41. loss_train = reconstruction_loss.eval(feed_dict={X: X_batch})
  42. print("\r{}".format(epoch), "Train MSE:", loss_train)
  43. params = dict([(var.name, var.eval()) for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)])
  44. hidden_val = hidden.eval(feed_dict={X: X_train})
  45. return hidden_val, params["hidden/kernel:0"], params["hidden/bias:0"], params["outputs/kernel:0"], params["outputs/bias:0"]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement