MIzadpanah

Untitled

Apr 23rd, 2017
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.74 KB | None | 0 0
  1.  
  2. با سلام. من برنامه زیر را در TensorFlow نوشتم:
  3.  
  4. import tensorflow as tf
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. try:
  8. from scipy import misc
  9. except ImportError:
  10. !pip install scipy
  11. from scipy import misc
  12.  
  13. training_size = 9
  14. img_size = 20*20*3
  15. training_data = np.empty(shape=(training_size, img_size))
  16.  
  17. import glob
  18. i = 0
  19. for filename in glob.glob('D:/Minutia/*.jpg'):
  20. image = misc.imread(filename)
  21. training_data[i] = image.reshape(-1)
  22. i+=1
  23. print(training_data[0].shape)
  24.  
  25. a = [0, 0, 0,1,1,1,2,2,2]
  26. b = tf.one_hot(a,3)
  27. sess = tf.Session()
  28. sess.run(b)
  29. from __future__ import division, print_function, absolute_import
  30.  
  31. import tflearn
  32. from tflearn.layers.core import input_data, dropout, fully_connected
  33. from tflearn.layers.conv import conv_2d, max_pool_2d
  34. from tflearn.layers.normalization import local_response_normalization
  35. from tflearn.layers.estimator import regression
  36.  
  37. network = input_data(shape=[None, 227, 227, 3])
  38. network = conv_2d(network, 96, 11, strides=4, activation='relu')
  39. network = max_pool_2d(network, 3, strides=2)
  40. network = local_response_normalization(network)
  41. network = conv_2d(network, 256, 5, activation='relu')
  42. network = max_pool_2d(network, 3, strides=2)
  43. network = local_response_normalization(network)
  44. network = conv_2d(network, 384, 3, activation='relu')
  45. network = conv_2d(network, 384, 3, activation='relu')
  46. network = conv_2d(network, 256, 3, activation='relu')
  47. network = max_pool_2d(network, 3, strides=2)
  48. network = local_response_normalization(network)
  49. network = fully_connected(network, 4096, activation='tanh')
  50. network = dropout(network, 0.5)
  51. network = fully_connected(network, 4096, activation='tanh')
  52. network = dropout(network, 0.5)
  53. network = fully_connected(network, 17, activation='softmax')
  54. network = regression(network, optimizer='momentum',
  55. loss='categorical_crossentropy',
  56. learning_rate=0.001)
  57. model = tflearn.DNN(network, checkpoint_path='model_alexnet',
  58. max_checkpoints=1, tensorboard_verbose=2)
  59. model.fit(training_data, b, n_epoch=1000, validation_set=0.1, shuffle=True,
  60. show_metric=True, batch_size=64, snapshot_step=200,
  61. snapshot_epoch=False, run_id='alexnet_oxflowers17')
  62.  
  63.  
  64.  
  65. که هنگام اجرای
  66.  
  67. model.fit(training_data, b, n_epoch=1000, validation_set=0.1, shuffle=True,
  68. show_metric=True, batch_size=64, snapshot_step=200,
  69. snapshot_epoch=False, run_id='alexnet_oxflowers17')
  70. با خطای زیر مواجه شدم:
  71.  
  72. Exception in thread Thread-8:
  73. Traceback (most recent call last):
  74. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 670, in _call_cpp_shape_fn_impl
  75. status)
  76. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\contextlib.py", line 66, in __exit__
  77. next(self.gen)
  78. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 469, in raise_exception_on_not_ok_status
  79. pywrap_tensorflow.TF_GetCode(status))
  80. tensorflow.python.framework.errors_impl.InvalidArgumentError: Shape must be rank 1 but is rank 2 for 'strided_slice' (op: 'StridedSlice') with input shapes: [9,3], [1,8], [1,8], [1].
  81.  
  82. During handling of the above exception, another exception occurred:
  83.  
  84. Traceback (most recent call last):
  85. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\threading.py", line 914, in _bootstrap_inner
  86. self.run()
  87. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\threading.py", line 862, in run
  88. self._target(*self._args, **self._kwargs)
  89. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tflearn\data_flow.py", line 187, in fill_feed_dict_queue
  90. data = self.retrieve_data(batch_ids)
  91. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tflearn\data_flow.py", line 222, in retrieve_data
  92. utils.slice_array(self.feed_dict[key], batch_ids)
  93. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tflearn\utils.py", line 187, in slice_array
  94. return X[start]
  95. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\array_ops.py", line 513, in _SliceHelper
  96. name=name)
  97. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\array_ops.py", line 671, in strided_slice
  98. shrink_axis_mask=shrink_axis_mask)
  99. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 3688, in strided_slice
  100. shrink_axis_mask=shrink_axis_mask, name=name)
  101. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 763, in apply_op
  102. op_def=op_def)
  103. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 2397, in create_op
  104. set_shapes_for_outputs(ret)
  105. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 1757, in set_shapes_for_outputs
  106. shapes = shape_func(op)
  107. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 1707, in call_with_requiring
  108. return call_cpp_shape_fn(op, require_shape_fn=True)
  109. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 610, in call_cpp_shape_fn
  110. debug_python_shape_fn, require_shape_fn)
  111. File "C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 675, in _call_cpp_shape_fn_impl
  112. raise ValueError(err.message)
  113. ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice' (op: 'StridedSlice') with input shapes: [9,3], [1,8], [1,8], [1].
Add Comment
Please, Sign In to add comment