Advertisement
Guest User

Untitled

a guest
Aug 19th, 2019
99
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 2.44 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3. import random
  4. from multiprocessing import pool as mpp
  5. import itertools as it
  6. import matplotlib.pyplot as plt
  7.  
  8.  
  9. def onehotencode(index, n):
  10. return [1.0 if i == index else 0.0 for i in range(n)]
  11.  
  12.  
  13. if __name__ == "__main__":
  14. (x_train_data, y_train_data), (test_x_data, test_y_data) = tf.keras.datasets.mnist.load_data()
  15.  
  16. x_train_data = np.reshape(x_train_data, (-1, 28 * 28))
  17. test_x_data = np.reshape(test_x_data, (-1, 28 * 28))
  18.  
  19. print(x_train_data.shape)
  20. print(test_x_data.shape)
  21.  
  22. pool = mpp.Pool(8)
  23. y_train_data = np.array(pool.starmap(onehotencode, zip(y_train_data, it.repeat(10))))
  24. test_y_data = np.array(pool.starmap(onehotencode, zip(test_y_data, it.repeat(10))))
  25. pool.close()
  26. pool.join()
  27.  
  28. input_layer = tf.keras.layers.Input((784,))
  29. hidden_layer_1 = tf.keras.layers.Dense(units=10, activation=tf.keras.activations.tanh)(input_layer)
  30. hidden_layer_2 = tf.keras.layers.Dense(units=10, activation=tf.keras.activations.tanh)(hidden_layer_1)
  31. hidden_layer_3 = tf.keras.layers.Dense(units=10, activation=tf.keras.activations.tanh)(hidden_layer_2)
  32. output_layer = tf.keras.layers.Dense(units=10, activation=tf.keras.activations.softmax)(hidden_layer_3)
  33.  
  34. model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
  35. model.compile(tf.keras.optimizers.SGD(0.01), tf.keras.losses.mean_squared_error, ["accuracy"])
  36.  
  37. plt.ion()
  38. plt.show()
  39.  
  40. sampleCount = 15
  41. sampleIndicies = random.sample(range(test_x_data.shape[0]), sampleCount)
  42. sample = test_x_data[sampleIndicies]
  43. sampleview = np.reshape(sample, (-1, 28, 28, 1))
  44. sampleview = np.tile(sampleview, (1, 1, 1, 3))
  45.  
  46. fig, ax = plt.subplots(sampleCount, 2)
  47.  
  48.  
  49. def showplt(epoch, logs):
  50. if epoch % 2 == 0:
  51. classifications = model.predict_on_batch(sample)
  52. for i in range(sampleCount):
  53. ax[i][0].cla() # clear the plot
  54. ax[i][1].cla() # clear the plot
  55.  
  56. ax[i][0].imshow(sampleview[i])
  57. ax[i][1].bar(range(10), classifications[i])
  58.  
  59. plt.draw()
  60. plt.pause(0.0001)
  61.  
  62.  
  63. pltcallback = tf.keras.callbacks.LambdaCallback(on_epoch_end=showplt)
  64. model.fit(x_train_data, y_train_data, 300, 1000, validation_data=[test_x_data, test_y_data],
  65. callbacks=[])
  66.  
  67. # dropout_layer3 = tf.keras.layers.Dropout(0.4)(hidden_layer_3)+
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement