Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import numpy as np
from numpy import hstack
from numpy import zeros
from numpy import ones
from numpy.random import rand
from numpy.random import randn
from keras.models import Sequential
from keras.layers import Dense
import keras.backend as K
from matplotlib import pyplot
class GanPointGraph_Keras(object):
    """Small Keras GAN that learns to generate 2-D points on the curve y = x**2.

    The "real" data are points ``(x, x*x)`` with ``x`` drawn by
    ``numpy.random.rand`` and shifted by -0.5.  A generator maps latent noise
    vectors to 2-D points and a discriminator scores points as real (1) or
    fake (0).

    Attributes:
        latent_dim: Length of each latent noise vector fed to the generator.
        discriminator: Compiled binary classifier over 2-D points.
        generator: Uncompiled model mapping latent vectors to 2-D points.
        gan_model: Compiled stack ``generator -> discriminator`` used to
            update the generator.
    """

    def __init__(self, latent_dim=5):
        """Build the discriminator, the generator and the combined GAN model.

        Args:
            latent_dim: Size of the generator's latent input (default 5, the
                value that was previously hard-coded).
        """
        self.latent_dim = latent_dim
        # The discriminator is built and compiled first; define_gan() then
        # flips its `trainable` flag before compiling the combined model.
        self.discriminator = self.define_discriminator()
        self.generator = self.define_generator(self.latent_dim)
        self.gan_model = self.define_gan(self.generator, self.discriminator)

    def define_discriminator(self, n_inputs=2):
        """Create and compile the discriminator.

        Args:
            n_inputs: Number of features per sample (2 for an (x, y) point).

        Returns:
            A compiled ``Sequential`` model with one 25-unit ReLU hidden layer
            and a single sigmoid output, trained with binary cross-entropy
            and the Adam optimizer.
        """
        model = Sequential()
        model.add(Dense(25, activation='relu', kernel_initializer='he_uniform', input_dim=n_inputs))
        model.add(Dense(1, activation='sigmoid'))
        model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
        # Debug output: shows the optimizer's learning rate at build time.
        # NOTE(review): newer Keras releases name this attribute
        # `learning_rate`; confirm `lr` exists in the installed version.
        print(K.eval(model.optimizer.lr))
        return model

    def define_generator(self, latent_dim, n_outputs=2):
        """Create the (uncompiled) generator.

        The generator is never trained on its own, only through the combined
        model built by ``define_gan``, so it is not compiled here.

        Args:
            latent_dim: Size of the latent input vector.
            n_outputs: Number of values produced per sample.

        Returns:
            A ``Sequential`` model with one 15-unit ReLU hidden layer and a
            linear output layer of ``n_outputs`` units.
        """
        model = Sequential()
        model.add(Dense(15, activation='relu', kernel_initializer='he_uniform', input_dim=latent_dim))
        model.add(Dense(n_outputs, activation='linear'))
        return model

    def define_gan(self, generator, discriminator):
        """Stack generator and discriminator into one compiled model.

        Args:
            generator: Model mapping latent vectors to samples.
            discriminator: Already-compiled classifier of samples.

        Returns:
            A compiled ``Sequential`` model ``generator -> discriminator``
            using binary cross-entropy and Adam.
        """
        # Freeze the discriminator inside the combined model so that training
        # the GAN is meant to update only the generator.
        # NOTE(review): this relies on Keras reading the `trainable` flag at
        # compile time, so the separately compiled discriminator keeps
        # learning in train() -- confirm for the installed Keras version.
        discriminator.trainable = False
        model = Sequential()
        model.add(generator)
        model.add(discriminator)
        model.compile(loss='binary_crossentropy', optimizer='adam')
        return model

    def generate_latent_points(self, n):
        """Draw ``n`` latent vectors from a standard normal distribution.

        Args:
            n: Number of latent vectors.

        Returns:
            Array of shape ``(n, self.latent_dim)``.
        """
        # Ask for the final shape directly rather than reshaping a flat draw.
        return randn(n, self.latent_dim)

    def generate_fake_samples(self, n):
        """Run the generator on ``n`` fresh latent vectors.

        Args:
            n: Number of samples to generate.

        Returns:
            The generator's predictions for ``n`` latent vectors.
        """
        x_input = self.generate_latent_points(n)
        # verbose=0 keeps Keras from printing a progress bar on every call,
        # which would otherwise flood the output during the training loop.
        return self.generator.predict(x_input, verbose=0)

    def generate_real_samples(self, n):
        """Sample ``n`` points from the target curve y = x**2.

        Args:
            n: Number of points.

        Returns:
            Array of shape ``(n, 2)``; column 0 is ``rand(n) - 0.5`` and
            column 1 is the square of column 0.
        """
        x1 = rand(n) - 0.5
        x2 = x1 * x1
        return hstack((x1.reshape(n, 1), x2.reshape(n, 1)))

    def train(self, n_batch=128):
        """Run one training step for the discriminator and the generator.

        The discriminator is updated on half a batch of real points (label 1)
        and half a batch of generated points (label 0).  The combined model is
        then updated on a full batch of latent vectors labelled 1, i.e. the
        generator is pushed towards outputs the discriminator scores as real.

        Args:
            n_batch: Total batch size (default 128, the value that was
                previously hard-coded).

        Raises:
            ValueError: If ``n_batch`` is smaller than 2, which would leave
                the discriminator with empty half-batches.
        """
        if n_batch < 2:
            raise ValueError('n_batch must be at least 2, got %r' % (n_batch,))
        half_batch = n_batch // 2

        # Discriminator update: real samples first, then generated ones.
        x_real = self.generate_real_samples(half_batch)
        y_real = ones((half_batch, 1))
        x_fake = self.generate_fake_samples(half_batch)
        y_fake = zeros((half_batch, 1))
        self.discriminator.train_on_batch(x_real, y_real)
        self.discriminator.train_on_batch(x_fake, y_fake)

        # Generator update through the combined model; the targets are all
        # ones because the generator wants its output classified as real.
        x_gan = self.generate_latent_points(n_batch)
        y_gan = ones((n_batch, 1))
        self.gan_model.train_on_batch(x_gan, y_gan)
def main(n_epochs=10000, plot_every=1000):
    """Train the point-graph GAN and periodically plot real vs. generated points.

    Args:
        n_epochs: Number of training steps to run (one ``train()`` call each).
        plot_every: Plot whenever ``epoch % plot_every == 0``; this includes
            epoch 0, i.e. after the very first training step.
    """
    g = GanPointGraph_Keras()
    for epoch in range(n_epochs):
        print('Epoch', epoch)
        g.train()
        if epoch % plot_every == 0:
            g_objects = g.generate_fake_samples(100)
            r_objects = g.generate_real_samples(100)
            pyplot.clf()
            pyplot.title('Keras iteration ' + str(epoch))
            # Column 0 is x and column 1 is y; real points are drawn in
            # black, generated points in red.
            pyplot.scatter(r_objects[:, 0], r_objects[:, 1], c='black')
            pyplot.scatter(g_objects[:, 0], g_objects[:, 1], c='red')
            # NOTE(review): with an interactive backend show() typically
            # blocks until the window is closed -- confirm that pausing the
            # training loop here is intended.
            pyplot.show()


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement