Advertisement
Guest User

Keras version

a guest
Jan 29th, 2020
105
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.13 KB | None | 0 0
  1. import numpy as np
  2. from numpy import hstack
  3. from numpy import zeros
  4. from numpy import ones
  5. from numpy.random import rand
  6. from numpy.random import randn
  7. from keras.models import Sequential
  8. from keras.layers import Dense
  9. import keras.backend as K
  10. from matplotlib import pyplot
  11.  
  12. class GanPointGraph_Keras(object):
  13.  
  14. def __init__(self):
  15. self.latent_dim = 5
  16. self.discriminator = self.define_discriminator()
  17. self.generator = self.define_generator(self.latent_dim)
  18. self.gan_model = self.define_gan(self.generator, self.discriminator)
  19.  
  20. def define_discriminator(self, n_inputs=2):
  21. model = Sequential()
  22. model.add(Dense(25, activation='relu', kernel_initializer='he_uniform', input_dim=n_inputs))
  23. model.add(Dense(1, activation='sigmoid'))
  24. model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
  25. print(K.eval(model.optimizer.lr))
  26. return model
  27.  
  28. def define_generator(self, latent_dim, n_outputs=2):
  29. model = Sequential()
  30. model.add(Dense(15, activation='relu', kernel_initializer='he_uniform', input_dim=latent_dim))
  31. model.add(Dense(n_outputs, activation='linear'))
  32. return model
  33.  
  34. def define_gan(self, generator, discriminator):
  35. discriminator.trainable = False
  36. model = Sequential()
  37. model.add(generator)
  38. model.add(discriminator)
  39. model.compile(loss='binary_crossentropy', optimizer='adam')
  40. return model
  41.  
  42. def generate_latent_points(self, n):
  43. x_input = randn(self.latent_dim * n)
  44. x_input = x_input.reshape(n, self.latent_dim)
  45. return x_input
  46.  
  47. def generate_fake_samples(self, n):
  48. x_input = self.generate_latent_points(n)
  49. X = self.generator.predict(x_input)
  50. return X
  51.  
  52. def generate_real_samples(self, n):
  53. X1 = rand(n) - 0.5
  54. X2 = X1 * X1
  55. X1 = X1.reshape(n, 1)
  56. X2 = X2.reshape(n, 1)
  57. X = hstack((X1, X2))
  58. return X
  59.  
  60. def train(self):
  61. n_batch = 128
  62. half_batch = int(n_batch / 2)
  63. x_real = self.generate_real_samples(half_batch)
  64. y_real = ones((half_batch, 1))
  65. x_fake = self.generate_fake_samples(half_batch)
  66. y_fake = zeros((half_batch, 1))
  67. self.discriminator.train_on_batch(x_real, y_real)
  68. self.discriminator.train_on_batch(x_fake, y_fake)
  69. x_gan = self.generate_latent_points(n_batch)
  70. y_gan = ones((n_batch, 1))
  71. self.gan_model.train_on_batch(x_gan, y_gan)
  72.  
  73. if __name__ == "__main__":
  74. g = GanPointGraph_Keras();
  75.  
  76. for epoch in range(10000):
  77. print('Epoch', epoch)
  78. g.train()
  79. if epoch % 1000 == 0:
  80. g_objects = g.generate_fake_samples(100)
  81. r_objects = g.generate_real_samples(100)
  82.  
  83. pyplot.clf()
  84. pyplot.title('Keras iteration ' + str(epoch))
  85. pyplot.scatter([i[0] for i in r_objects], [i[1] for i in r_objects], c='black')
  86. pyplot.scatter([i[0] for i in g_objects], [i[1] for i in g_objects], c='red')
  87. pyplot.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement