Guest User

Untitled

a guest
Jan 22nd, 2018
58
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.70 KB | None | 0 0
  1. import cv2
  2. import numpy as np
  3. import os
  4. from random import shuffle
  5. from tqdm import tqdm
  6. import tensorflow as tf
  7. import tflearn
  8. import itertools
  9. from tflearn.layers.conv import conv_2d, max_pool_2d
  10. from tflearn.layers.core import input_data, dropout, fully_connected
  11. from tflearn.layers.estimator import regression
  12.  
  13. TRAIN_DIR = '/Users/xyz/Desktop/train'
  14. TEST_DIR = '/Users/xyz/Desktop/test'
  15. IMG_SIZE = 50
  16. LR = 1e-3
  17.  
  18. MODEL_NAME = 'abcvsdef-{}-{}.model'.format(LR, '2conv-basic') # just so we remember which saved model is which, sizes must match
  19.  
  20. def label_img(img):
  21. word_label = img.split('.')[-3]
  22. # conversion to one-hot array [cat,dog]
  23. if word_label == 'm':
  24. return [1,0]
  25. elif word_label == 'n':
  26. return [0,1]
  27.  
  28. # process the training imnages and their labels into arrays
  29.  
  30. def create_train_data():
  31. training_data = []
  32. #for img in tqdm(os.listdir(TRAIN_DIR)):
  33. for root, dirs, files in os.walk(TRAIN_DIR):
  34. for file1, file2 in itertools.izip_longest(files[::2], files[1::2]):
  35. img1 = cv2.imread(root + '/' + file1)
  36. img2 = cv2.imread(root + '/' + file2)
  37. img1 = cv2.resize(img1,(IMG_SIZE,IMG_SIZE))
  38. img2 = cv2.resize(img2,(IMG_SIZE,IMG_SIZE))
  39. image_pairs = np.concatenate((img1, img2), axis=1)
  40. label = label_img(file1)
  41. training_data.append([np.array(image_pairs),np.array(label)])
  42.  
  43. return training_data
  44.  
  45. # train the data
  46. train_data = create_train_data()
  47.  
  48. # construct the CNN
  49. convnet = input_data(shape=[None, IMG_SIZE, IMG_SIZE, 1], name='input')
  50.  
  51. convnet = conv_2d(convnet, 32, 5, activation='relu')
  52. convnet = max_pool_2d(convnet, 5)
  53.  
  54. convnet = conv_2d(convnet, 64, 5, activation='relu')
  55. convnet = max_pool_2d(convnet, 5)
  56.  
  57. convnet = conv_2d(convnet, 128, 5, activation='relu')
  58. convnet = max_pool_2d(convnet, 5)
  59.  
  60. convnet = conv_2d(convnet, 64, 5, activation='relu')
  61. convnet = max_pool_2d(convnet, 5)
  62.  
  63. convnet = conv_2d(convnet, 32, 5, activation='relu')
  64. convnet = max_pool_2d(convnet, 5)
  65.  
  66. convnet = fully_connected(convnet, 1024, activation='relu')
  67. convnet = dropout(convnet, 0.8)
  68.  
  69. convnet = fully_connected(convnet, 2, activation='softmax')
  70. convnet = regression(convnet, optimizer='adam', learning_rate=LR, loss='categorical_crossentropy', name='targets')
  71.  
  72. model = tflearn.DNN(convnet, tensorboard_dir='log')
  73.  
  74. train = train_data[:4]
  75. test = train_data[4:]
  76.  
  77. X = np.array([i[0] for i in train]).reshape(-1,IMG_SIZE,IMG_SIZE,1)
  78. Y = [i[1] for i in train]
  79.  
  80. test_x = np.array([i[0] for i in test]).reshape(-1,IMG_SIZE,IMG_SIZE,1)
  81. test_y = [i[1] for i in test]
  82.  
  83. model.fit({'input': X}, {'targets': Y}, n_epoch=10, validation_set=({'input': test_x}, {'targets': test_y}),
  84. snapshot_step=500, show_metric=True, run_id=MODEL_NAME)
  85.  
  86. model.save(MODEL_NAME)
  87.  
  88. ---------------------------------
  89. Run id: abcvsdef-0.001-2conv-basic.model
  90. Log directory: log/
  91. ---------------------------------
  92. Training samples: 24
  93. Validation samples: 24
  94. --
  95. Exception in thread Thread-3:
  96. Traceback (most recent call last):
  97. File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/threading.py", line 810, in __bootstrap_inner
  98. self.run()
  99. File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/threading.py", line 763, in run
  100. self.__target(*self.__args, **self.__kwargs)
  101. File "/Library/Python/2.7/site-packages/tflearn/data_flow.py", line 187, in fill_feed_dict_queue
  102. data = self.retrieve_data(batch_ids)
  103. File "/Library/Python/2.7/site-packages/tflearn/data_flow.py", line 222, in retrieve_data
  104. utils.slice_array(self.feed_dict[key], batch_ids)
  105. File "/Library/Python/2.7/site-packages/tflearn/utils.py", line 187, in slice_array
  106. return X[start]
  107. IndexError: index 19 is out of bounds for axis 0 with size 4
Add Comment
Please, Sign In to add comment