jack06215
[keras] GAN customised train_step()
Aug 2nd, 2020
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, LeakyReLU, Activation
from tensorflow.keras.models import Model, Sequential
import tensorflow as tf
import numpy as np
import pandas as pd

# Credit-card transaction data: keep only the normal (Class == 0) rows and
# the 28 anonymised PCA features V1..V28 (Class, Time and Amount are dropped).
# !gdown --id 1-o-1k5OWe1yhsuJDnr8bhHTPz9SwF5T6
df = pd.read_csv('creditcard.csv')
df = df[df['Class'] == 0].reset_index(drop=True)
df_raw = df.drop(['Class', 'Time'], axis=1)
df_raw = df_raw.iloc[:, :28].values
del df

feat_dim = df_raw.shape[1]
batch_size = 700

df_raw_v = np.reshape(df_raw.astype(np.float32), (-1, feat_dim))
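# Editorial note: the generators below end in tanh, so their outputs lie in
# [-1, 1], while the raw V1..V28 features are unbounded. A minimal min-max
# scaling sketch (illustrative names lo/hi, not in the original paste) that
# could be applied to df_raw_v at this point:
#
# lo, hi = df_raw_v.min(axis=0), df_raw_v.max(axis=0)
# df_raw_v = 2.0 * (df_raw_v - lo) / (hi - lo) - 1.0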
dataset = tf.data.Dataset.from_tensor_slices(df_raw_v)
dataset = dataset.shuffle(buffer_size=512).batch(batch_size)

class MyDiscriminator(Model):
  # last_activation defaults to None (linear) so the model emits raw logits,
  # which is what the sigmoid_cross_entropy_with_logits losses below expect.
  def __init__(self, d_hidden_dim=50, last_activation=None, **kwargs):
    super().__init__(**kwargs)  # handle standard args (e.g. name)
    self.hidden1 = Dense(d_hidden_dim * 2, name='discriminator_h1', activation=LeakyReLU(0.3))
    self.hidden2 = Dense(d_hidden_dim, name='discriminator_h2', activation=LeakyReLU(0.3))
    self.d_output = Dense(1, name='discriminator_y', activation=last_activation)

  def call(self, inputs, with_feature=False):
    hidden1 = self.hidden1(inputs)
    d_latent_feat = self.hidden2(hidden1)
    d_output = self.d_output(d_latent_feat)

    # Optionally expose the penultimate features alongside the score.
    if with_feature:
      return d_output, d_latent_feat
    return d_output

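# Illustrative (not in the original paste): calling the subclassed model
# directly lets with_feature=True return the penultimate activations, e.g.
# for feature matching or anomaly scoring:
#
# d = MyDiscriminator()
# score, feat = d(tf.zeros((2, feat_dim)), with_feature=True)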
class MyGenerator(Model):
  # The tanh output bounds samples to [-1, 1]; this matches the data only if
  # the real features are scaled into that range (see the scaling note above).
  def __init__(self, output_dim, d_hidden_dim=100, last_activation='tanh', **kwargs):
    super().__init__(**kwargs)  # handle standard args (e.g. name)
    self.hidden1 = Dense(d_hidden_dim, name='generator_h1', activation=LeakyReLU(0.3))
    self.g_output = Dense(output_dim, name='generator_x_flat', activation=last_activation)

  def call(self, inputs, with_feature=False):
    g_latent_feat = self.hidden1(inputs)
    g_output = self.g_output(g_latent_feat)

    if with_feature:
      return g_output, g_latent_feat
    return g_output

# Create the discriminator
d_hidden_dim = 50

# Create the generator
g_latent_dim = 100
g_output_dim = feat_dim
g_noise_dim = 10

# Subclassed variants wrapped in Sequential; note Input's shape must be a
# tuple, not a bare int.
my_disc = Sequential(
    [
        Input(shape=(df_raw.shape[1],)),
        MyDiscriminator()
    ], name='discriminator_def'
)

my_gen = Sequential(
    [
        Input(shape=(g_noise_dim,)),
        MyGenerator(df_raw.shape[1])
    ], name='generator_def'
)

my_disc.summary()

discriminator = Sequential(
    [
        # Input layer
        Input(shape=(g_output_dim,)),

        # Layer 1
        Dense(d_hidden_dim * 2, name="discriminator_h1"),
        LeakyReLU(0.2),

        # Layer 2
        Dense(d_hidden_dim, name="discriminator_h2"),
        LeakyReLU(0.2),

        # Output layer: raw logits, no activation, since the losses below
        # use tf.nn.sigmoid_cross_entropy_with_logits.
        Dense(1, name="discriminator_y"),
    ],
    name="discriminator",
)


generator = keras.Sequential(
    [
        # Input layer
        Input(shape=(g_noise_dim,)),

        # Layer 1
        Dense(g_latent_dim, name="generator_h1"),
        LeakyReLU(0.2),

        # Output layer: tanh bounds generated samples to [-1, 1]
        Dense(g_output_dim, name="generator_x_flat"),
        Activation('tanh'),
    ],
    name="generator",
)

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def train_step(self, real_samples):
        # Unpack the batch (tf.data may yield a tuple)
        if isinstance(real_samples, tuple):
            real_samples = real_samples[0]

        # Sample random points in the latent space
        batch_size = tf.shape(real_samples)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them into fake samples
        generated_samples = self.generator(random_latent_vectors)

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(generated_samples)
            predictions_r = self.discriminator(real_samples)
            d_loss = self.d_loss_fn(predictions_r, predictions)

        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # Sample fresh latent points for the generator update
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Train the generator. Gradients are only applied to the generator's
        # weights, so the discriminator is *not* updated here.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.g_loss_fn(predictions)

        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        return {"d_loss": d_loss, "g_loss": g_loss}

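# Debugging note (editorial, not in the original paste): under fit() this
# train_step runs inside a tf.function graph; to step through it eagerly,
# Keras supports Model.compile(run_eagerly=True), e.g.
# super().compile(run_eagerly=True) in GAN.compile above.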
# Standard (non-saturating) GAN losses. With an optimal discriminator, this
# objective corresponds (up to constants) to the Jensen–Shannon divergence
# between the real and generated distributions.
def discriminator_loss(d_real, d_fake, loss_type='JSD'):
    if loss_type in ['JSD', 'jsd']:
        real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real),
                                                    logits=d_real))
        fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake),
                                                    logits=d_fake))
        return real_loss + fake_loss
    raise ValueError(f'Unknown loss_type: {loss_type}')

def generator_loss(d_fake, loss_type='JSD'):
    # Non-saturating trick: label the fakes as real rather than flipping the sign
    if loss_type in ['JSD', 'jsd']:
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_fake),
                                                    logits=d_fake))
    raise ValueError(f'Unknown loss_type: {loss_type}')

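# Quick smoke test (illustrative, not in the original paste): confidently
# correct logits should drive both losses towards zero.
# print(discriminator_loss(tf.constant([[10.0]]), tf.constant([[-10.0]])))  # ~0
# print(generator_loss(tf.constant([[10.0]])))                              # ~0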
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=g_noise_dim)
# gan = GAN(discriminator=my_disc, generator=my_gen, latent_dim=g_noise_dim)
gan.compile(d_optimizer=keras.optimizers.Adam(), g_optimizer=keras.optimizers.Adam(),
            d_loss_fn=discriminator_loss, g_loss_fn=generator_loss)

train_result = gan.fit(dataset, epochs=20, verbose=1)
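
# Post-training inspection (editorial sketch, not in the original paste):
# sample noise, decode with the trained generator, and compare the generated
# samples against the real feature statistics.
z = tf.random.normal(shape=(5, g_noise_dim))
fake_samples = gan.generator(z, training=False)
print(fake_samples.numpy())
print('real mean/std:', df_raw_v.mean(axis=0), df_raw_v.std(axis=0))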