daily pastebin goal
62%
SHARE
TWEET

Untitled

a guest Mar 21st, 2019 62 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. """TensorFlow 2.0 implementation of vanilla Autoencoder."""
  2. from __future__ import absolute_import
  3. from __future__ import division
  4. from __future__ import print_function
  5.  
  6. __version__ = '0.0.1'
  7. __author__ = 'Abien Fred Agarap'
  8.  
  9. import numpy as np
  10. import tensorflow as tf
  11.  
  # Seed TensorFlow's random number generator.
  tf.random.set_seed(1)

  # Hyperparameters.
  batch_size = 128
  epochs = 10
  learning_rate = 0.001
  momentum = 0.9
  intermediate_dim = 64  # units in every hidden / code layer
  original_dim = 784  # 28 * 28, the length of one flattened image

  # Keep only the training images; the labels and the test split are discarded.
  (training_features, _), _ = tf.keras.datasets.mnist.load_data()
  # Scale by the largest pixel value, flatten each image into one row, and
  # cast to float32 before wrapping the array in a batched dataset.
  training_features = training_features / np.max(training_features)
  training_features = training_features.reshape(len(training_features), -1)
  training_features = training_features.astype(np.float32)
  training_dataset = tf.data.Dataset.from_tensor_slices(training_features).batch(batch_size)
  25.  
  26.  
  class Encoder(tf.keras.layers.Layer):
    """Encoder half of the autoencoder: two stacked ReLU dense layers.

    Both layers have `intermediate_dim` units, so the code returned by
    `call` has `intermediate_dim` features.
    """

    def __init__(self, intermediate_dim):
      """Create the hidden layer and the output (code) layer.

      Args:
        intermediate_dim: Number of units in each of the two dense layers.
      """
      super(Encoder, self).__init__()
      relu = tf.nn.relu
      self.hidden_layer = tf.keras.layers.Dense(intermediate_dim, activation=relu)
      self.output_layer = tf.keras.layers.Dense(intermediate_dim, activation=relu)

    def call(self, input_features):
      """Return the code obtained by applying both layers to `input_features`."""
      return self.output_layer(self.hidden_layer(input_features))
  36.  
  class Decoder(tf.keras.layers.Layer):
    """Decoder half of the autoencoder: two stacked ReLU dense layers.

    The hidden layer has `intermediate_dim` units and the output layer has
    `original_dim` units.
    """

    def __init__(self, intermediate_dim, original_dim):
      """Create the hidden layer and the reconstruction layer.

      Args:
        intermediate_dim: Number of units in the hidden layer.
        original_dim: Number of units in the output layer.
      """
      super(Decoder, self).__init__()
      relu = tf.nn.relu
      self.hidden_layer = tf.keras.layers.Dense(intermediate_dim, activation=relu)
      self.output_layer = tf.keras.layers.Dense(original_dim, activation=relu)

    def call(self, code):
      """Return the output layer's activation for the given `code`."""
      return self.output_layer(self.hidden_layer(code))
  46.  
  class Autoencoder(tf.keras.Model):
    """Model that chains an `Encoder` and a `Decoder`."""

    def __init__(self, intermediate_dim, original_dim):
      """Build the encoder and the decoder.

      Args:
        intermediate_dim: Passed to both the encoder and the decoder.
        original_dim: Passed to the decoder as its output width.
      """
      super(Autoencoder, self).__init__()
      self.encoder = Encoder(intermediate_dim=intermediate_dim)
      self.decoder = Decoder(
          intermediate_dim=intermediate_dim,
          original_dim=original_dim,
      )

    def call(self, input_features):
      """Encode `input_features`, then decode the result and return it."""
      return self.decoder(self.encoder(input_features))
  57.  
  # Instantiate the model and an SGD-with-momentum optimizer from the
  # module-level hyperparameters.
  autoencoder = Autoencoder(intermediate_dim, original_dim)
  opt = tf.optimizers.SGD(momentum=momentum, learning_rate=learning_rate)
  60.  
  def loss(model, original):
    """Return the mean squared error between `model(original)` and `original`.

    Args:
      model: Callable that maps `original` to its reconstruction.
      original: Batch of inputs to reconstruct.

    Returns:
      The mean, over all elements, of the squared difference.
    """
    residual = model(original) - original
    return tf.reduce_mean(tf.square(residual))
  64.  
  def train(loss, model, opt, original):
    """Apply one gradient update to `model` using the batch `original`.

    Args:
      loss: Callable invoked as `loss(model, original)`; its result is
        differentiated with respect to the model's trainable variables.
      model: Object exposing `trainable_variables`.
      opt: Optimizer exposing `apply_gradients`.
      original: Batch of inputs handed to `loss`.

    Returns:
      The loss value computed on `original` before the update is applied.
    """
    # Only the forward pass has to be recorded; the gradient computation and
    # the optimizer update are done after the tape context has been closed.
    with tf.GradientTape() as tape:
      loss_value = loss(model, original)
    # Read the variables after the forward pass, as the original code did.
    variables = model.trainable_variables
    gradients = tape.gradient(loss_value, variables)
    opt.apply_gradients(zip(gradients, variables))
    return loss_value
  70.  
  # Summaries are written to the 'tmp' directory.
  writer = tf.summary.create_file_writer('tmp')

  with writer.as_default():
    with tf.summary.record_if(True):
      # Summary steps must keep increasing across epochs. A per-epoch batch
      # index would restart at 0 and reuse the step values of earlier epochs,
      # so a single counter spans the whole run.
      global_step = 0
      for epoch in range(epochs):
        for batch_features in training_dataset:
          train(loss, autoencoder, opt, batch_features)
          # Loss is evaluated again after the update, on the same batch.
          loss_values = loss(autoencoder, batch_features)
          # Restore the 28x28 single-channel layout expected by image summaries.
          image_shape = (batch_features.shape[0], 28, 28, 1)
          original = tf.reshape(batch_features, image_shape)
          reconstructed = tf.reshape(autoencoder(batch_features), image_shape)
          tf.summary.scalar('loss', loss_values, step=global_step)
          tf.summary.image('original', original, max_outputs=10, step=global_step)
          tf.summary.image('reconstructed', reconstructed, max_outputs=10, step=global_step)
          global_step += 1
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top