Advertisement
Guest User

Untitled

a guest
Oct 17th, 2017
135
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.61 KB | None | 0 0
  # Core numeric and plotting dependencies.
  import numpy as np
  import matplotlib.pyplot as plt

  # SciPy is imported lazily-installed: if the import fails, install the
  # package with pip and retry once.
  try:
      from scipy import misc
  except ImportError:
      # `!pip install scipy` is IPython shell syntax and a SyntaxError in a
      # plain Python file. Running pip through the current interpreter works
      # both in a notebook and in a script, and installs into the same
      # environment that is executing this code.
      import subprocess
      import sys

      subprocess.check_call([sys.executable, "-m", "pip", "install", "scipy"])
      from scipy import misc
  8.  
  # Dataset dimensions: 300 images of 20x20 pixels with 3 channels each.
  training_size = 300
  img_size = 20*20*3
  training_data = np.empty(shape=(training_size, 20, 20, 3))

  import glob

  # glob returns matches in arbitrary, OS-dependent order; sort them so that
  # row k of `training_data` always holds the same file.
  # NOTE(review): the label list defined later is positional, so this assumes
  # the sorted file names group by class - confirm against the dataset.
  image_paths = sorted(glob.glob('D:/Minutia/PrincipleWrinkleMinutia/*.jpg'))

  # `np.empty` does not initialise its buffer, so too few files would leave
  # garbage rows in the training set, and too many would index past the end.
  # Fail early with a clear message in either case.
  if len(image_paths) != training_size:
      raise ValueError(
          "expected %d training images, found %d"
          % (training_size, len(image_paths))
      )

  # `scipy.misc.imread` is absent from newer SciPy releases; keep using it
  # where it exists and otherwise fall back to matplotlib's reader, which is
  # already imported.
  read_image = getattr(misc, "imread", plt.imread)

  # Each image must already be 20x20x3; NumPy raises on a shape mismatch.
  for index, filename in enumerate(image_paths):
      training_data[index] = read_image(filename)

  # Sanity check: shape of a single sample.
  print(training_data[0].shape)
  20.  
  # Integer class labels as a flat list: 100 entries of class 0, followed by
  # 100 of class 1 and 100 of class 2 (300 labels in total).
  a = [label for label in range(3) for _ in range(100)]
  from sklearn.preprocessing import OneHotEncoder

  # Turn the label list into an array, then into a dense one-hot matrix `b`
  # with one row per label and one column per distinct class.
  a = np.asarray(a)
  # The encoder expects a 2-D column of samples, hence the reshape.
  # The `sparse=False` keyword was renamed in newer scikit-learn releases;
  # taking the default sparse result and densifying it with `.toarray()`
  # yields the same array on both old and new versions.
  b = OneHotEncoder().fit_transform(a.reshape(-1, 1)).toarray()
  # Alternative one-hot encoding through TensorFlow, left disabled:
  #b = tf.one_hot(a,3)
  #sess = tf.Session()
  #sess.run(b)
  import tensorflow as tf

  # Clear the default graph before any layers are created, so re-running this
  # script in the same session does not stack a second model on the first.
  tf.reset_default_graph()

  # NOTE: the `from __future__ import ...` line that used to sit here was
  # removed. Python only accepts it as the first statement of a module (it is
  # a SyntaxError anywhere else), and nothing visible in this script relies on
  # the features it enabled.

  import tflearn
  from tflearn.layers.core import input_data, dropout, fully_connected
  from tflearn.layers.conv import conv_2d, max_pool_2d
  from tflearn.layers.normalization import local_response_normalization
  from tflearn.layers.estimator import regression

  # A shortened draft of the network used to be built here. Its result was
  # discarded because `network` is rebuilt from a fresh `input_data` layer
  # further down, so the draft only added unused layers and has been dropped.
  # AlexNet-style classifier for 20x20 3-channel images with 3 output classes.
  # The TFLearn layer helpers used below are imported earlier in the file; the
  # duplicated import lines and the mid-file `from __future__` import (a
  # SyntaxError outside the first statement of a module) were removed.

  # Input: batches of 20x20x3 images.
  network = input_data(shape=[None, 20, 20, 3])

  # Convolutional feature extractor.
  network = conv_2d(network, 96, 11, strides=4, activation='relu')
  network = max_pool_2d(network, 3, strides=2)
  network = local_response_normalization(network)
  network = conv_2d(network, 256, 5, activation='relu')
  network = max_pool_2d(network, 3, strides=2)
  network = local_response_normalization(network)
  network = conv_2d(network, 384, 3, activation='relu')
  network = conv_2d(network, 384, 3, activation='relu')
  network = conv_2d(network, 256, 3, activation='relu')
  network = max_pool_2d(network, 3, strides=2)
  network = local_response_normalization(network)

  # Fully connected classifier head with dropout after each hidden layer.
  network = fully_connected(network, 4096, activation='tanh')
  network = dropout(network, 0.5)
  network = fully_connected(network, 4096, activation='tanh')
  network = dropout(network, 0.5)

  # One softmax output per class.
  network = fully_connected(network, 3, activation='softmax')

  # Training objective: momentum optimizer on categorical cross-entropy.
  network = regression(network, optimizer='momentum',
                       loss='categorical_crossentropy',
                       learning_rate=0.001)

  # Keep a single checkpoint under the 'model_alexnet' prefix.
  model = tflearn.DNN(network, checkpoint_path='model_alexnet',
                      max_checkpoints=1, tensorboard_verbose=2)

  # Train on the one-hot targets `b`, not the integer labels `a`: the network
  # ends in a 3-way softmax trained with categorical cross-entropy, so the
  # targets need one column per class. `a` has shape (300,) and does not match
  # that output.
  model.fit(training_data, b, n_epoch=1000, validation_set=0.1, shuffle=True,
            show_metric=True, batch_size=64, snapshot_step=200,
            snapshot_epoch=False, run_id='alexnet_oxflowers17')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement