Advertisement
Guest User

Untitled

a guest
Jul 24th, 2017
73
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.72 KB | None | 0 0
  1. ```python
  2. from glob import glob
  3.  
  4. # 画像の検索
  5. image_files = glob('/home/imgs/*.png') # * はワイルドカード
  6. ```
  7.  
  8.  
  9. ```python
  10. import cv2
  11. ```
  12.  
  13.  
  14. ```python
  15. %matplotlib inline
  16. import matplotlib.pyplot as plt
  17. ```
  18.  
  19. # 画像の事前処理
  20.  
  21.  
  22. ```python
  23. import os
  24. HEIGHT, WIDTH = 50, 50 # 縮小後のサイズ
  25. for image_file in image_files:
  26. # OpenCVでの画像の読み込み
  27. img = cv2.imread(image_file)
  28. # 画像の縮小
  29. img_small = cv2.resize(img, (WIDTH, HEIGHT))
  30. # 縮小後のファイル名
  31. img_small_file = './' + os.path.basename(image_file)
  32. # 画像の保存
  33. cv2.imwrite(img_small_file, img_small)
  34. ```
  35.  
  36.  
  37. ```python
  38. import pandas as pd
  39.  
  40. # ラベルデータの読み込み
  41. df = pd.read_csv('data.csv')
  42. ```
  43.  
  44. # 入力変数と出力変数の切り分け
  45.  
  46.  
  47. ```python
  48. x, t = [], []
  49.  
  50. for (i, row) in df.iterrows():
  51. # 各行のデータからfilepathとlabelを取得
  52. filepath, label = row['filepath'], row['label']
  53.  
  54. # 画像の読み込み
  55. img = cv2.imread(filepath)
  56.  
  57. # RGBチャンネル(2)を一番前に持ってくる
  58. # + 0−1の範囲に正規化(255で割る)
  59. _x = np.transpose(img, (2, 0, 1)) / 255.0
  60.  
  61. # リストに追加
  62. x.append(_x) # 正規化した画像データ
  63. t.append(row['label']) # ラベル
  64. ```
  65.  
  66.  
  67. ```python
  68. import numpy as np
  69.  
  70. # numpyの形式に変換 + データ型の変更
  71. x = np.array(x, dtype=np.float32)
  72. t = np.array(t, dtype=np.int32)
  73. ```
  74.  
  75.  
  76. ```python
  77. from chainer.datasets import tuple_dataset, split_dataset_random
  78.  
  79. # Chainer推奨のdataset形式
  80. # dataset = list(zip(x, t))
  81. dataset = tuple_dataset.TupleDataset(x, t)
  82.  
  83. # 訓練データ(50%)と検証データ(50%)に分割
  84. n_train = int( len(dataset) * 0.5 )
  85. train, test = split_dataset_random(dataset, n_train, seed=1)
  86. ```
  87.  
  88. # モデル定義
  89.  
  90.  
  91. ```python
  92. import chainer
  93. from chainer import Chain, Variable
  94. import chainer.links as L
  95. import chainer.functions as F
  96. ```
  97.  
  98.  
  99. ```python
  100. class CNN(Chain):
  101. def __init__(self, n_units, n_output):
  102. super().__init__()
  103. with self.init_scope():
  104. self.conv1 = L.Convolution2D(in_channels=3, out_channels=12, ksize=3, stride=1)
  105. self.l1 = L.Linear(None, n_units)
  106. self.l2 = L.Linear(None, n_output)
  107.  
  108. def __call__(self, x):
  109. z = F.relu(self.conv1(x))
  110. h1 = F.max_pooling_2d(z, 3, 3)
  111. h2 = self.l1(h1)
  112. return self.l2(h2)
  113. ```
  114.  
  115.  
  116. ```python
  117. # モデルの宣言
  118. cnn = CNN(50,2)
  119. model = L.Classifier(cnn)
  120. ```
  121.  
  122.  
  123. ```python
  124. # Optimizerの設定
  125. optimizer = chainer.optimizers.Adam()
  126. optimizer.setup(model)
  127. ```
  128.  
  129.  
  130. ```python
  131. # Iteratorの設定
  132. batchsize = 3
  133. train_iter = chainer.iterators.SerialIterator(train, batchsize)
  134. test_iter = chainer.iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)
  135. ```
  136.  
  137.  
  138. ```python
  139. # Updateの設定
  140. from chainer import training
  141. updater = training.StandardUpdater(train_iter, optimizer)
  142. ```
  143.  
  144.  
  145. ```python
  146. # Trainerとそのextensionsの設定
  147. from chainer.training import extensions
  148.  
  149. epoch = 40
  150. trainer = training.Trainer(updater, (epoch, 'epoch'), out='result')
  151.  
  152. # 評価データで評価
  153. trainer.extend(extensions.Evaluator(test_iter, model))
  154.  
  155. # 学習結果の途中を表示する
  156. trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
  157.  
  158. # 1エポックごとに、trainデータに対するaccuracyと、testデータに対するaccuracyを出力させる
  159. trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']), trigger=(1, 'epoch'))
  160. ```
  161.  
  162.  
  163. ```python
  164. trainer.run()
  165. ```
  166.  
  167. epoch main/accuracy validation/main/accuracy elapsed_time
  168. 1 0.583333 0.5 0.081398
  169. 2 0.666667 0.916667 0.127008
  170. 3 0.916667 0.75 0.172755
  171. 4 0.916667 0.916667 0.218207
  172. 5 1 0.916667 0.264898
  173. 6 1 0.916667 0.324137
  174. 7 1 0.833333 0.370681
  175. 8 1 0.916667 0.416805
  176. 9 1 0.916667 0.46221
  177. 10 1 0.833333 0.508246
  178. 11 1 0.833333 0.565425
  179. 12 1 0.916667 0.612338
  180. 13 1 0.916667 0.658242
  181. 14 1 0.916667 0.703609
  182. 15 1 0.916667 0.749522
  183. 16 1 0.916667 0.805929
  184. 17 1 0.916667 0.852826
  185. 18 1 0.916667 0.898924
  186. 19 1 0.916667 0.945089
  187. 20 1 0.916667 0.990832
  188. 21 1 0.916667 1.05026
  189. 22 1 0.916667 1.09726
  190. 23 1 0.916667 1.14336
  191. 24 1 0.916667 1.18942
  192. 25 1 0.916667 1.23655
  193. 26 1 0.916667 1.30444
  194. 27 1 0.916667 1.3509
  195. 28 1 0.916667 1.41909
  196. 29 1 0.916667 1.46566
  197. 30 1 0.916667 1.51978
  198. 31 1 0.916667 1.57249
  199. 32 1 0.916667 1.61922
  200. 33 1 0.916667 1.66567
  201. 34 1 0.916667 1.71288
  202. 35 1 0.916667 1.76941
  203. 36 1 0.916667 1.81707
  204. 37 1 0.916667 1.86409
  205. 38 1 0.916667 1.91094
  206. 39 1 0.916667 1.95727
  207. 40 1 0.916667 2.01688
  208.  
  209.  
  210. # 学習結果を確認
  211.  
  212.  
  213. ```python
  214. import json
  215. with open('result/log') as f:
  216. logs = json.load(f)
  217. ```
  218.  
  219.  
  220. ```python
  221. loss_train = [ log['main/loss'] for log in logs ]
  222. loss_test = [ log['validation/main/loss'] for log in logs ]
  223. ```
  224.  
  225.  
  226. ```python
  227. plt.plot(loss_train, label='train') # 訓練データ
  228. plt.plot(loss_test, label='test') # 検証データ
  229. plt.legend() # 凡例表示
  230. plt.show()
  231. ```
  232.  
  233.  
  234. ![png](output_22_0.png)
  235.  
  236.  
  237.  
  238. ```python
  239. accuracy_train = [ log['main/accuracy'] for log in logs ]
  240. accuracy_test = [ log['validation/main/accuracy'] for log in logs ]
  241. ```
  242.  
  243.  
  244. ```python
  245. plt.plot(accuracy_train, label='train') # 訓練データ
  246. plt.plot(accuracy_test, label='test') # 検証データ
  247. plt.legend() # 凡例表示
  248. plt.show()
  249. ```
  250.  
  251.  
  252. ![png](output_24_0.png)
  253.  
  254.  
  255. # 予測値の計算(推論)
  256.  
  257.  
  258. ```python
  259. for datum in test:
  260. _x, _t = datum
  261. # クラスの予測値を計算
  262. y = cnn(np.array([_x])) # cnnの中のコール関数が呼ばれている?()
  263. y = F.softmax(y).data # softmax関数で足して1になるようデータ調整
  264. index = np.argmax(y) # もっとも値が大きいものを取得
  265. # 結果の表示
  266. print('教師データ: ', _t, ' 予測値: ', index, '予測値リスト', y)
  267. ```
  268.  
  269. 教師データ: 1 予測値: 1 予測値リスト [[ 0.04602067 0.95397931]]
  270. 教師データ: 0 予測値: 0 予測値リスト [[ 0.82391012 0.17608985]]
  271. 教師データ: 1 予測値: 1 予測値リスト [[ 0.00875639 0.9912436 ]]
  272. 教師データ: 0 予測値: 0 予測値リスト [[ 0.8621484 0.13785164]]
  273. 教師データ: 1 予測値: 1 予測値リスト [[ 0.01890844 0.98109162]]
  274. 教師データ: 1 予測値: 1 予測値リスト [[ 0.00945963 0.99054039]]
  275. 教師データ: 1 予測値: 1 予測値リスト [[ 0.00305405 0.99694592]]
  276. 教師データ: 0 予測値: 1 予測値リスト [[ 0.38699657 0.61300349]]
  277. 教師データ: 0 予測値: 0 予測値リスト [[ 0.98359168 0.01640825]]
  278. 教師データ: 1 予測値: 1 予測値リスト [[ 0.10968747 0.89031249]]
  279. 教師データ: 0 予測値: 0 予測値リスト [[ 0.98453289 0.01546716]]
  280. 教師データ: 0 予測値: 0 予測値リスト [[ 0.97657734 0.0234226 ]]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement