Advertisement
Guest User

Untitled

a guest
Aug 23rd, 2017
55
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.78 KB | None | 0 0
  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. import os
  4. %matplotlib inline
  5.  
  6. images = "dataset/test_dataset_png/"
  7. image_dir = os.path.join(os.getcwd(), images)
  8. imagenames = [os.path.join(image_dir, f) for f in os.listdir(image_dir)]
  9.  
  10. label = "dataset/test_dataset_csv/label.csv"
  11. labelname = [os.path.join(os.getcwd(), label)]
  12.  
  13. imagename_queue = tf.train.string_input_producer(imagenames)
  14. labelname_queue = tf.train.string_input_producer(labelname)
  15.  
  16. img_reader = tf.WholeFileReader()
  17. label_reader = tf.TextLineReader()
  18.  
  19. _, image = img_reader.read(imagename_queue)
  20. _, label = label_reader.read(labelname_queue)
  21.  
  22. decoded_img = tf.image.decode_png(image)
  23. reshaped_img = tf.reshape(decoded_img, shape=[61, 49, 1])
  24. reshaped_img = tf.cast(reshaped_img, tf.float32)
  25.  
  26. decoded_label = tf.decode_csv(label, record_defaults=[[0]])
  27.  
  28. x, y_ = tf.train.batch([reshaped_img, decoded_label], 10)
  29.  
  30. conv1 = tf.layers.conv2d(x, filters=10, kernel_size=[3, 3], padding="SAME")
  31. conv2 = tf.layers.conv2d(conv1, filters=10, kernel_size=[3, 3], padding="SAME")
  32. # pool2 = tf.layers.max_pooling2d(conv2, pool_size=[2, 2], strides=[2, 2])
  33.  
  34. conv3 = tf.layers.conv2d(conv2, filters=10, kernel_size=[3, 3], padding="SAME")
  35. # pool3 = tf.layers.max_pooling2d(conv3, pool_size=[2, 2], strides=[2, 2])
  36.  
  37. conv4 = tf.layers.conv2d(conv3, filters=10, kernel_size=[3, 3], padding="SAME")
  38. # pool4 = tf.layers.max_pooling2d(conv4, pool_size=[2, 2], strides=[2, 2])
  39.  
  40. flat = tf.reshape(conv4, shape=[-1, 61*49*10])
  41.  
  42. fc1 = tf.layers.dense(flat, 5000)
  43. fc2 = tf.layers.dense(fc1, 1000)
  44. out = tf.layers.dense(fc2, 3)
  45.  
  46. with tf.Session() as sess:
  47. coord = tf.train.Coordinator()
  48. thread = tf.train.start_queue_runners(sess, coord)
  49. for i in range(100):
  50. age = sess.run(decoded_label)
  51. print(age)
  52. coord.request_stop()
  53. coord.join(thread)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement