Advertisement
Guest User

Untitled

a guest
May 22nd, 2018
151
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy as np
  2. import tensorflow as tf
  3. import pickle as pkl
  4. from functools import reduce
  5. from skimage.transform import resize
  6.  
  7.  
  8. def to_one_hot(y):
  9.     shape_one_hot = (np.size(y, 0), 36)
  10.     one_hot_vectors = np.zeros(shape_one_hot)
  11.     for i, label in enumerate(y):
  12.         one_hot_vectors[i, label] = 1
  13.     return one_hot_vectors
  14.  
  15. def rescale_images(x, new_shape):
  16.     with tf.Session() as sess:
  17.         to_resize = tf.image.resize_area(x.reshape([len(x), 56, 56, 1]), new_shape)
  18.         resized = sess.run(to_resize)
  19.     resized = resized.reshape([len(resized), reduce(lambda x, y: x * y, new_shape)])
  20.     return resized
  21.  
  22.  
  23. def next_batch(x_train, y_train, batch_size):
  24.     idx = np.arange(0, len(x_train))
  25.     np.random.shuffle(idx)
  26.     batch_idx = idx[:batch_size]
  27.  
  28.     batch_x = x_train[batch_idx]
  29.     batch_y = y_train[batch_idx]
  30.  
  31.     return batch_x, batch_y
  32.  
  33.  
  34. def get_accuracy(x_val, y_val, x, y, y_, sess, keep_prob):
  35.     correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  36.     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  37.     return sess.run(accuracy, feed_dict={x: x_val, y_: y_val, keep_prob: 1.0})
  38.  
  39. def weight_variable(shape):
  40.   initial = tf.truncated_normal(shape, stddev=0.1)
  41.   return tf.Variable(initial)
  42.  
  43. def bias_variable(shape):
  44.   initial = tf.constant(0.1, shape=shape)
  45.   return tf.Variable(initial)
  46.  
  47. def conv2d(x, W):
  48.   return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
  49.  
  50. def max_pool_2x2(x):
  51.   return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
  52.                         strides=[1, 2, 2, 1], padding='SAME')
  53.  
  54.  
  55. def train_model(x_train, y_train, x_val, y_val):
  56.     # hyperparameters
  57.     batch_size = 50
  58.     epochs = 150
  59.  
  60.     # dimensions
  61.     input_shape = np.size(x_train, 1)
  62.     output_shape = 36
  63.  
  64.     x = tf.placeholder(tf.float32, [None, input_shape], name="network_input")
  65.     y_ = tf.placeholder(tf.float32, [None, output_shape], name="network_output")
  66.  
  67.     x_image = tf.reshape(x, [-1, 28, 28, 1])
  68.  
  69.     W_conv1 = weight_variable([5, 5, 1, 32])
  70.     b_conv1 = bias_variable([32])
  71.  
  72.     h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
  73.     h_pool1 = max_pool_2x2(h_conv1)
  74.  
  75.     W_conv2 = weight_variable([5, 5, 32, 64])
  76.     b_conv2 = bias_variable([64])
  77.  
  78.     h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
  79.     h_pool2 = max_pool_2x2(h_conv2)
  80.  
  81.     W_fc1 = weight_variable([7 * 7 * 64, 1024])
  82.     b_fc1 = bias_variable([1024])
  83.  
  84.     h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
  85.     h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
  86.  
  87.     keep_prob = tf.placeholder(tf.float32)
  88.     h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
  89.  
  90.     W_fc2 = weight_variable([1024, 36])
  91.     b_fc2 = bias_variable([36])
  92.  
  93.     y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
  94.  
  95.     cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
  96.     optimizer = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
  97.     correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
  98.     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  99.     sess = tf.InteractiveSession()
  100.     sess.run(tf.global_variables_initializer())
  101.  
  102.     for epoch_idx in range(epochs):
  103.         for batch_idx in range(400):
  104.             batch_x, batch_y = next_batch(x_train, y_train, batch_size)
  105.             sess.run(optimizer, feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.75})
  106.         val_accuracy = accuracy.eval(feed_dict={x: x_val, y_: y_val, keep_prob: 1.0})
  107.         print("Epoch %d, validation accuracy %g" % (epoch_idx, val_accuracy))
  108.  
  109.  
  110. def load_data(filename):
  111.     with open(filename, 'rb') as file:
  112.         x_train, y_train, x_val, y_val = pkl.load(file)
  113.     return x_train, y_train, x_val, y_val
  114.  
  115.  
  116. def save_data(filename, data):
  117.     with open(filename, 'wb') as output:
  118.         pkl.dump(data, output, pkl.HIGHEST_PROTOCOL)
  119.  
  120.  
  121. if __name__ == '__main__':
  122.     x_train, y_train, x_val, y_val = load_data('rotated-15_17.05.pkl')
  123.     y_train = to_one_hot(y_train)
  124.     y_val = to_one_hot(y_val)
  125.     x_train = rescale_images(x_train, [28, 28])
  126.     x_val = rescale_images(x_val, [28, 28])
  127.     train_model(x_train, y_train, x_val, y_val)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement