Advertisement
Guest User

Untitled

a guest
Dec 7th, 2016
141
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 11.14 KB | None | 0 0
  1. """Tutorial on how to create a convolutional autoencoder w/ Tensorflow.
  2.  
  3. Parag K. Mital, Jan 2016
  4. """
  5. import io
  6. import math
  7. import os
  8. import time
  9.  
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12. import tensorflow as tf
  13. from PIL import Image
  14. from tensorflow.python.framework import ops
  15. from tensorflow.python.ops import gen_nn_ops
  16.  
  17. from libs.activations import lrelu
  18. from libs.utils import corrupt
  19.  
  20. np.set_printoptions(threshold=np.nan)
  21.  
  22.  
  23. @ops.RegisterGradient("MaxPoolWithArgmax")
  24. def _MaxPoolWithArgmaxGrad(op, grad, unused_argmax_grad):
  25. return gen_nn_ops._max_pool_grad(op.inputs[0],
  26. op.outputs[0],
  27. grad,
  28. op.get_attr("ksize"),
  29. op.get_attr("strides"),
  30. padding=op.get_attr("padding"),
  31. data_format='NHWC')
  32.  
  33.  
  34. def unravel_argmax(argmax, shape):
  35. output_list = [argmax // (shape[2] * shape[3]),
  36. argmax % (shape[2] * shape[3]) // shape[3]]
  37. return tf.pack(output_list)
  38.  
  39.  
  40. def unpool_layer2x2_batch(bottom, argmax):
  41. bottom_shape = tf.shape(bottom)
  42. top_shape = [bottom_shape[0], bottom_shape[1] * 2, bottom_shape[2] * 2, bottom_shape[3]]
  43.  
  44. batch_size = top_shape[0]
  45. height = top_shape[1]
  46. width = top_shape[2]
  47. channels = top_shape[3]
  48.  
  49. argmax_shape = tf.to_int64([batch_size, height, width, channels])
  50. argmax = unravel_argmax(argmax, argmax_shape)
  51.  
  52. t1 = tf.to_int64(tf.range(channels))
  53. t1 = tf.tile(t1, [batch_size * (width // 2) * (height // 2)])
  54. t1 = tf.reshape(t1, [-1, channels])
  55. t1 = tf.transpose(t1, perm=[1, 0])
  56. t1 = tf.reshape(t1, [channels, batch_size, height // 2, width // 2, 1])
  57. t1 = tf.transpose(t1, perm=[1, 0, 2, 3, 4])
  58.  
  59. t2 = tf.to_int64(tf.range(batch_size))
  60. t2 = tf.tile(t2, [channels * (width // 2) * (height // 2)])
  61. t2 = tf.reshape(t2, [-1, batch_size])
  62. t2 = tf.transpose(t2, perm=[1, 0])
  63. t2 = tf.reshape(t2, [batch_size, channels, height // 2, width // 2, 1])
  64.  
  65. t3 = tf.transpose(argmax, perm=[1, 4, 2, 3, 0])
  66.  
  67. t = tf.concat(4, [t2, t3, t1])
  68. indices = tf.reshape(t, [(height // 2) * (width // 2) * channels * batch_size, 4])
  69.  
  70. x1 = tf.transpose(bottom, perm=[0, 3, 1, 2])
  71. values = tf.reshape(x1, [-1])
  72.  
  73. delta = tf.SparseTensor(indices, values, tf.to_int64(top_shape))
  74. return tf.sparse_tensor_to_dense(tf.sparse_reorder(delta))
  75.  
  76.  
  77. class Network:
  78. IMAGE_HEIGHT = 250
  79. IMAGE_WIDTH = 250
  80. IMAGE_CHANNELS = 1
  81.  
  82. def __init__(self,
  83. n_filters=[1, 10],
  84. filter_sizes=[3, 3],
  85. corruption=False):
  86. """Build a deep denoising autoencoder w/ tied weights.
  87.  
  88. Parameters
  89. ----------
  90. input_shape : list, optional
  91. Description
  92. n_filters : list, optional
  93. Description
  94. filter_sizes : list, optional
  95. Description
  96.  
  97. Raises
  98. ------
  99. ValueError
  100. Description
  101. """
  102. # %%
  103. # input to the network
  104. self.inputs = tf.placeholder(tf.float32, [None, self.IMAGE_HEIGHT, self.IMAGE_WIDTH, self.IMAGE_CHANNELS], name='x')
  105. self.targets = tf.placeholder(tf.float32, [None, self.IMAGE_HEIGHT, self.IMAGE_WIDTH, 1], name='x')
  106.  
  107. current_input = self.inputs
  108.  
  109. # Optionally apply denoising autoencoder
  110. if corruption:
  111. current_input = corrupt(current_input)
  112.  
  113. # Build the encoder
  114. encoder = []
  115. shapes = []
  116. for layer_index, output_channels in enumerate(n_filters[1:]):
  117. number_of_channels = current_input.get_shape().as_list()[3]
  118. shapes.append(current_input.get_shape().as_list())
  119.  
  120. W = tf.get_variable('W' + str(layer_index), shape=(filter_sizes[layer_index], filter_sizes[layer_index], number_of_channels, output_channels))
  121. b = tf.Variable(tf.zeros([output_channels]))
  122. encoder.append(W)
  123. output = lrelu(tf.add(tf.nn.conv2d(current_input, W, strides=[1, 2, 2, 1], padding='SAME'), b))
  124. current_input = output
  125.  
  126. current_input, argmax_1 = tf.nn.max_pool_with_argmax(current_input, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool1')
  127.  
  128. # store the latent representation
  129. # z = current_input
  130.  
  131. current_input = unpool_layer2x2_batch(current_input, argmax_1)
  132.  
  133. encoder.reverse()
  134. shapes.reverse()
  135.  
  136. # Build the decoder using the same weights
  137. for layer_index, shape in enumerate(shapes):
  138. W = encoder[layer_index]
  139. b = tf.Variable(tf.zeros([W.get_shape().as_list()[2]]))
  140. output = lrelu(tf.add(
  141. tf.nn.conv2d_transpose(
  142. current_input, W,
  143. tf.pack([tf.shape(self.inputs)[0], shape[1], shape[2], shape[3]]),
  144. strides=[1, 2, 2, 1], padding='SAME'), b))
  145. current_input = output
  146.  
  147. current_input = tf.sigmoid(current_input)
  148.  
  149. self.segmentation_result = current_input # [batch_size, self.IMAGE_HEIGHT, self.IMAGE_WIDTH, self.IMAGE_CHANNELS]
  150.  
  151. # segmentation_as_classes = tf.reshape(self.y, [50 * self.IMAGE_HEIGHT * self.IMAGE_WIDTH, 1])
  152. # targets_as_classes = tf.reshape(self.targets, [50 * self.IMAGE_HEIGHT * self.IMAGE_WIDTH])
  153. # print(self.y.get_shape())
  154. # self.cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(segmentation_as_classes, targets_as_classes))
  155.  
  156. # MSE loss
  157. self.cost = tf.sqrt(tf.reduce_mean(tf.square(self.segmentation_result - self.targets)))
  158.  
  159. self.train_op = tf.train.AdamOptimizer().minimize(self.cost)
  160.  
  161.  
  162. class Dataset:
  163. def __init__(self, folder='data28_28', batch_size=50):
  164. self.batch_size = batch_size
  165.  
  166. self.train_inputs = []
  167. self.train_targets = []
  168. self.test_inputs = []
  169. self.test_targets = []
  170.  
  171. train_files = []
  172. test_files = []
  173. with io.open(os.path.join(folder, 'train.txt'), 'r') as reader:
  174. train_files += reader.readlines()
  175.  
  176. with io.open(os.path.join(folder, 'validation.txt'), 'r') as reader:
  177. train_files += reader.readlines()
  178.  
  179. with io.open(os.path.join(folder, 'test.txt'), 'r') as reader:
  180. test_files += reader.readlines()
  181.  
  182. for train_file in train_files:
  183. input_image, target_image = train_file.strip().split(' ')
  184. train_image = np.array(Image.open(input_image).convert('L')) # .convert('L')) -> grayscale (1-channel)
  185. train_image = np.multiply(train_image, 1.0 / 255)
  186. self.train_inputs.append(train_image)
  187. self.train_targets.append(np.array(Image.open(target_image).convert('1')).astype(np.float32)) # .convert('1')) -> binary
  188.  
  189. for test_file in test_files:
  190. input_image, target_image = test_file.strip().split(' ')
  191. test_image = np.array(Image.open(input_image).convert('L'))
  192. test_image = np.multiply(test_image, 1.0 / 255)
  193. self.test_inputs.append(test_image)
  194. self.test_targets.append(np.array(Image.open(target_image).convert('1')).astype(np.float32))
  195.  
  196. self.pointer = 0
  197.  
  198. def num_batches_in_epoch(self):
  199. return int(math.floor(len(self.train_inputs) / self.batch_size))
  200.  
  201. def reset_batch_pointer(self):
  202. permutation = np.random.permutation(len(self.train_inputs))
  203. self.train_inputs = [self.train_inputs[i] for i in permutation]
  204. self.train_targets = [self.train_targets[i] for i in permutation]
  205.  
  206. self.pointer = 0
  207.  
  208. def next_batch(self):
  209. inputs = []
  210. targets = []
  211. # print(self.batch_size, self.pointer, self.train_inputs.shape, self.train_targets.shape)
  212. for i in range(self.batch_size):
  213. inputs.append(np.array(self.train_inputs[self.pointer + i]))
  214. targets.append(np.array(self.train_targets[self.pointer + i]))
  215.  
  216. self.pointer += self.batch_size
  217.  
  218. return np.array(inputs), np.array(targets)
  219.  
  220.  
  221. def test_mnist():
  222. """Test the convolutional autoencder using MNIST."""
  223. # %%
  224.  
  225. dataset = Dataset()
  226.  
  227. inputs, targets = dataset.next_batch()
  228. print(inputs.shape, targets.shape)
  229. # print(targets[0].astype(np.float32))
  230.  
  231. # Image.fromarray(targets[0] * 255).show()
  232.  
  233. # Image.fromarray(targets[0], 'RGB').show()
  234.  
  235. # load MNIST as before
  236. # mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
  237. # mean_img = np.mean(dataset.train_inputs)
  238.  
  239. network = Network()
  240.  
  241. with tf.Session() as sess:
  242. sess.run(tf.initialize_all_variables())
  243.  
  244. # Fit all training data
  245. batch_size = 100
  246. n_epochs = 10
  247. for epoch_i in range(n_epochs):
  248. dataset.reset_batch_pointer()
  249.  
  250. for batch_i in range(dataset.num_batches_in_epoch()):
  251. start = time.time()
  252. batch_inputs, batch_targets = dataset.next_batch()
  253. batch_inputs = np.reshape(batch_inputs, (dataset.batch_size, network.IMAGE_HEIGHT, network.IMAGE_WIDTH, 1))
  254. batch_targets = np.reshape(batch_targets, (dataset.batch_size, network.IMAGE_HEIGHT, network.IMAGE_WIDTH, 1))
  255. # train = np.array([img - mean_img for img in batch_inputs]).reshape((dataset.batch_size, network.IMAGE_HEIGHT, network.IMAGE_WIDTH, network.IMAGE_CHANNELS))
  256. cost, _ = sess.run([network.cost, network.train_op], feed_dict={network.inputs: batch_inputs, network.targets: batch_targets})
  257. end = time.time()
  258. print('{}/{}, epoch: {}, cost: {}, batch time: {}'.format(epoch_i * dataset.num_batches_in_epoch() + batch_i,
  259. n_epochs * dataset.num_batches_in_epoch(),
  260. epoch_i, cost, end - start))
  261.  
  262. # Plot example reconstructions
  263. n_examples = 10
  264. test_inputs, test_targets = dataset.test_inputs[:n_examples], dataset.test_targets[:n_examples]
  265.  
  266. test_segmentation = sess.run(network.segmentation_result, feed_dict={network.inputs: np.reshape(test_inputs, [10, 250, 250, 1])})
  267. fig, axs = plt.subplots(4, n_examples, figsize=(10, 2))
  268. for example_i in range(n_examples):
  269. axs[0][example_i].imshow(test_inputs[example_i], cmap='gray')
  270. axs[1][example_i].imshow(test_targets[example_i], cmap='gray')
  271. axs[2][example_i].imshow(test_segmentation[example_i], cmap='gray')
  272.  
  273. test_image_thresholded = np.array([0 if x < 0.5 else 1 for x in test_segmentation[example_i].flatten()])
  274. axs[3][example_i].imshow(np.reshape(test_image_thresholded, [network.IMAGE_HEIGHT, network.IMAGE_WIDTH]), cmap='gray')
  275. # fig.show()
  276. # plt.draw()
  277.  
  278. plt.savefig('figure{}.jpg'.format(batch_i + epoch_i * dataset.num_batches_in_epoch()))
  279.  
  280.  
  281. if __name__ == '__main__':
  282. test_mnist()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement