import numpy as np
import tensorflow as tf
import pickle as pkl
from functools import reduce


def to_one_hot(y):
    # Convert integer class labels in [0, 36) to one-hot vectors.
    shape_one_hot = (np.size(y, 0), 36)
    one_hot_vectors = np.zeros(shape_one_hot)
    for i, label in enumerate(y):
        one_hot_vectors[i, label] = 1
    return one_hot_vectors


def rescale_images(x, new_shape):
    # Resize flattened 56x56 grayscale images to new_shape using area
    # interpolation, then flatten them again.
    with tf.Session() as sess:
        to_resize = tf.image.resize_area(x.reshape([len(x), 56, 56, 1]), new_shape)
        resized = sess.run(to_resize)
    resized = resized.reshape([len(resized), reduce(lambda a, b: a * b, new_shape)])
    return resized


def next_batch(x_train, y_train, batch_size):
    # Draw a random mini-batch; batches are sampled independently on each
    # call rather than by partitioning the training set.
    idx = np.arange(0, len(x_train))
    np.random.shuffle(idx)
    batch_idx = idx[:batch_size]

    batch_x = x_train[batch_idx]
    batch_y = y_train[batch_idx]

    return batch_x, batch_y


def get_accuracy(x_val, y_val, x, y, y_, sess, keep_prob):
    # Evaluate classification accuracy of the logits y against the one-hot
    # labels y_ on the given data, with dropout disabled.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return sess.run(accuracy, feed_dict={x: x_val, y_: y_val, keep_prob: 1.0})


def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')


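# train_model builds a small convolutional classifier (two 5x5 conv + max-pool
# blocks, a 1024-unit fully connected layer with dropout, and a 36-way output)
# and trains it with Adam (learning rate 1e-4) for 150 epochs of 400 random
# mini-batches of 50 examples each, printing validation accuracy every epoch.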
def train_model(x_train, y_train, x_val, y_val):
    # hyperparameters
    batch_size = 50
    epochs = 150

    # dimensions
    input_shape = np.size(x_train, 1)
    output_shape = 36

    x = tf.placeholder(tf.float32, [None, input_shape], name="network_input")
    y_ = tf.placeholder(tf.float32, [None, output_shape], name="network_output")

    x_image = tf.reshape(x, [-1, 28, 28, 1])

    # first convolutional block: 28x28x1 -> 28x28x32 -> 14x14x32
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])

    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    # second convolutional block: 14x14x32 -> 14x14x64 -> 7x7x64
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])

    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    # fully connected layer with dropout
    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])

    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    # output layer: 36 logits
    W_fc2 = weight_variable([1024, 36])
    b_fc2 = bias_variable([36])

    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
    optimizer = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    for epoch_idx in range(epochs):
        for batch_idx in range(400):
            batch_x, batch_y = next_batch(x_train, y_train, batch_size)
            sess.run(optimizer, feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.75})
        val_accuracy = accuracy.eval(feed_dict={x: x_val, y_: y_val, keep_prob: 1.0})
        print("Epoch %d, validation accuracy %g" % (epoch_idx, val_accuracy))


def load_data(filename):
    with open(filename, 'rb') as file:
        x_train, y_train, x_val, y_val = pkl.load(file)
    return x_train, y_train, x_val, y_val


def save_data(filename, data):
    with open(filename, 'wb') as output:
        pkl.dump(data, output, pkl.HIGHEST_PROTOCOL)


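# The pickle file loaded below is assumed to contain a 4-tuple
# (x_train, y_train, x_val, y_val), where each x is an array of flattened
# 56x56 grayscale images (shape [N, 3136]) and each y is an array of integer
# class labels in [0, 36). A file in that format could be written with the
# save_data helper above, e.g.:
#
#     save_data('rotated-15_17.05.pkl', (x_train, y_train, x_val, y_val))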
if __name__ == '__main__':
    x_train, y_train, x_val, y_val = load_data('rotated-15_17.05.pkl')
    y_train = to_one_hot(y_train)
    y_val = to_one_hot(y_val)
    x_train = rescale_images(x_train, [28, 28])
    x_val = rescale_images(x_val, [28, 28])
    train_model(x_train, y_train, x_val, y_val)