Guest User

Untitled

a guest
Jan 20th, 2017
87
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.88 KB | None | 0 0
  1. # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
  2. # pylint: disable=superfluous-parens, no-member, invalid-name
  3. import sys
  4. sys.path.insert(0, "../../python")
  5. import mxnet as mx
  6. import numpy as np
  7. import cv2, random
  8.  
  9. from io import BytesIO
  10. from captcha.image import ImageCaptcha
  11.  
  12. class OCRBatch(object):
  13. def __init__(self, data_names, data, label_names, label):
  14. self.data = data
  15. self.label = label
  16. self.data_names = data_names
  17. self.label_names = label_names
  18.  
  19. @property
  20. def provide_data(self):
  21. return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
  22.  
  23. @property
  24. def provide_label(self):
  25. return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
  26.  
  27. def gen_rand():
  28. num = random.randint(0, 9999)
  29. buf = str(num)
  30. while len(buf) < 4:
  31. buf = "0" + buf
  32. return buf
  33.  
  34. def get_label(buf):
  35. return np.array([int(x) for x in buf])
  36.  
  37. class OCRIter(mx.io.DataIter):
  38. def __init__(self, count, batch_size, num_label, height, width):
  39. super(OCRIter, self).__init__()
  40. self.captcha = ImageCaptcha(fonts=['./data/OpenSans-Regular.ttf'])
  41. self.batch_size = batch_size
  42. self.count = count
  43. self.height = height
  44. self.width = width
  45. self.provide_data = [('data', (batch_size, 3, height, width))]
  46. self.provide_label = [('softmax_label', (self.batch_size, num_label))]
  47.  
  48. def __iter__(self):
  49. for k in range(self.count / self.batch_size):
  50. data = []
  51. label = []
  52. for i in range(self.batch_size):
  53. num = gen_rand()
  54. img = self.captcha.generate(num)
  55. img = np.fromstring(img.getvalue(), dtype='uint8')
  56. img = cv2.imdecode(img, cv2.IMREAD_COLOR)
  57. img = cv2.resize(img, (self.width, self.height))
  58. cv2.imwrite("./tmp" + str(i % 10) + ".png", img)
  59. img = np.multiply(img, 1/255.0)
  60. img = img.transpose(2, 0, 1)
  61. data.append(img)
  62. label.append(get_label(num))
  63.  
  64. data_all = [mx.nd.array(data)]
  65. label_all = [mx.nd.array(label)]
  66. data_names = ['data']
  67. label_names = ['softmax_label']
  68.  
  69. data_batch = OCRBatch(data_names, data_all, label_names, label_all)
  70. yield data_batch
  71.  
  72. def reset(self):
  73. pass
  74.  
  75. def get_ocrnet():
  76. data = mx.symbol.Variable('data')
  77. label = mx.symbol.Variable('softmax_label')
  78. conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=32)
  79. pool1 = mx.symbol.Pooling(data=conv1, pool_type="max", kernel=(2,2), stride=(1, 1))
  80. relu1 = mx.symbol.Activation(data=pool1, act_type="relu")
  81.  
  82. conv2 = mx.symbol.Convolution(data=relu1, kernel=(5,5), num_filter=32)
  83. pool2 = mx.symbol.Pooling(data=conv2, pool_type="avg", kernel=(2,2), stride=(1, 1))
  84. relu2 = mx.symbol.Activation(data=pool2, act_type="relu")
  85.  
  86. conv3 = mx.symbol.Convolution(data=relu2, kernel=(3,3), num_filter=32)
  87. pool3 = mx.symbol.Pooling(data=conv3, pool_type="avg", kernel=(2,2), stride=(1, 1))
  88. relu3 = mx.symbol.Activation(data=pool3, act_type="relu")
  89.  
  90. flatten = mx.symbol.Flatten(data = relu3)
  91. fc1 = mx.symbol.FullyConnected(data = flatten, num_hidden = 512)
  92. fc21 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
  93. fc22 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
  94. fc23 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
  95. fc24 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
  96. fc2 = mx.symbol.Concat(*[fc21, fc22, fc23, fc24], dim = 0)
  97. label = mx.symbol.transpose(data = label)
  98. label = mx.symbol.Reshape(data = label, target_shape = (0, ))
  99. return mx.symbol.SoftmaxOutput(data = fc2, label = label, name = "softmax")
  100.  
  101.  
  102. def Accuracy(label, pred):
  103. label = label.T.reshape((-1, ))
  104. hit = 0
  105. total = 0
  106. for i in range(pred.shape[0] / 4):
  107. ok = True
  108. for j in range(4):
  109. k = i * 4 + j
  110. if np.argmax(pred[k]) != int(label[k]):
  111. ok = False
  112. break
  113. if ok:
  114. hit += 1
  115. total += 1
  116. return 1.0 * hit / total
  117.  
  118. network = get_ocrnet()
  119. devs = [mx.gpu(0)]
  120. model = mx.model.FeedForward(ctx = devs,
  121. symbol = network,
  122. num_epoch = 15,
  123. learning_rate = 0.001,
  124. wd = 0.00001,
  125. initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
  126. momentum = 0.9)
  127.  
  128. data_train = OCRIter(100000, 50, 4, 30, 80)
  129. data_test = OCRIter(1000, 50, 4, 30, 80)
  130.  
  131. import logging
  132. head = '%(asctime)-15s %(message)s'
  133. logging.basicConfig(level=logging.DEBUG, format=head)
  134.  
  135. model.fit(X = data_train, eval_data = data_test, eval_metric = Accuracy, batch_end_callback=mx.callback.Speedometer(32, 50),)
Add Comment
Please, Sign In to add comment