Guest User

Untitled

a guest
Jan 18th, 2019
92
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.37 KB | None | 0 0
  1. # -*- coding:utf-8 -*-
  2.  
  3. import keras
  4. from keras.applications.vgg19 import VGG19
  5. from keras.applications.densenet import decode_predictions
  6. from PIL import Image
  7. import numpy as np
  8. import os.path
  9. from DenseNet.DenseNet_ccc import densenet
  10. from keras.models import Model
  11. from keras.layers import Conv2D, Dense, Input, add, Activation, AveragePooling2D, GlobalAveragePooling2D, Lambda, \
  12. concatenate
  13. from keras import optimizers, regularizers
  14. from matplotlib import pyplot as plt
  15. import cv2
  16. import pickle
  17. import copy
  18.  
  19.  
  def count_paint(path):
      """Cut the hint-character strip of an image into single-glyph tiles.

      The strip is rows 344..383 and columns 0..343 of the image at *path*.
      It is converted to grayscale and binarized with a threshold of 160.
      Columns without any black pixel are treated as blank; runs of ink
      between blank columns become segments, which are then merged (segments
      that almost touch, then segments narrower than 10 columns) before
      being cropped.

      Args:
          path: Path of the image file to read.

      Returns:
          A tuple ``(imgs, thresh1)``: ``imgs`` is the list of cropped glyph
          images and ``thresh1`` is the binarized strip they were cut from.
          Glyphs narrower than 40 columns are padded left and right with
          white (255) up to 40 columns.  Wider glyphs are returned as is,
          after their shape is printed and they are displayed with
          matplotlib.

      Raises:
          OSError: If OpenCV cannot read or decode the file.
      """
      img = cv2.imread(path)
      if img is None:
          # cv2.imread reports failure by returning None instead of raising.
          raise OSError("cannot read image: {!r}".format(path))

      strip = img[344:384, 0:344]
      gray = cv2.cvtColor(strip, cv2.COLOR_BGR2GRAY)
      _, thresh1 = cv2.threshold(gray, 160, 255, cv2.THRESH_BINARY)

      # A column is blank when none of its pixels is black (0).
      blank_cols = np.flatnonzero((thresh1 != 0).all(axis=0)).tolist()

      segments = _find_ink_segments(blank_cols)
      segments = _merge_touching_segments(segments)
      segments = _merge_narrow_segments(segments)

      imgs = []
      for start, end in segments:
          glyph = thresh1[0:40, start:end]
          size = glyph.shape
          if size[1] < 40:
              # Centre the glyph on a 40-column white canvas.
              padding = 40 - size[1]
              left = padding // 2
              glyph = cv2.copyMakeBorder(
                  glyph, 0, 0, left, padding - left,
                  cv2.BORDER_CONSTANT, value=[255, 255, 255])
          else:
              # Unexpectedly wide crop: report and display it for inspection.
              print(size, '++++++++')
              plt.imshow(glyph, cmap='gray')
              plt.show()
          imgs.append(glyph)
      return imgs, thresh1


  def _find_ink_segments(blank_cols, min_gap=3):
      """Return (start, end) pairs of consecutive blank columns enclosing ink."""
      return [
          (left, right)
          for left, right in zip(blank_cols, blank_cols[1:])
          if right - left >= min_gap
      ]


  def _merge_touching_segments(segments, max_gap=2, max_width=40):
      """Join neighbours fewer than *max_gap* columns apart if the result fits."""
      segments = list(segments)
      touching = [
          k for k in range(len(segments) - 1)
          if segments[k + 1][0] - segments[k][1] < max_gap
      ]
      # Work right to left: earlier indices stay valid, and a segment that
      # already absorbed its right neighbour is measured at its new width.
      for k in reversed(touching):
          start, end = segments[k][0], segments[k + 1][1]
          if end - start <= max_width:
              segments[k:k + 2] = [(start, end)]
      return segments


  def _merge_narrow_segments(segments, min_width=10):
      """Attach each segment narrower than *min_width* to a neighbour.

      The neighbour giving the smaller combined span wins (ties go right).
      A missing neighbour is never chosen, and a narrow segment without any
      neighbour is left untouched.
      """
      last = len(segments) - 1
      # k in dropped  <=>  the boundary between segment k and k + 1 goes away.
      # A set keeps a boundary chosen by both of its sides from being
      # removed twice.
      dropped = set()
      for k, (start, end) in enumerate(segments):
          if end - start >= min_width:
              continue
          span_left = end - segments[k - 1][0] if k > 0 else None
          span_right = segments[k + 1][1] - start if k < last else None
          if span_left is None and span_right is None:
              continue
          if span_right is None or (
                  span_left is not None and span_left < span_right):
              dropped.add(k - 1)
          else:
              dropped.add(k)

      merged = []
      for k, segment in enumerate(segments):
          if merged and (k - 1) in dropped:
              merged[-1] = (merged[-1][0], segment[1])
          else:
              merged.append(segment)
      return merged
  87.  
  def get_label_dict():
      """Load and return the pickled label table stored in ``./chinese_labels``.

      The path is relative to the current working directory.

      Returns:
          The object stored in the pickle file.  The rest of the script
          indexes it by predicted class index.

      Raises:
          OSError: If the file cannot be opened.
          pickle.UnpicklingError: If the file is not a valid pickle.
      """
      # NOTE(review): unpickling can execute arbitrary code; only load a
      # label file that comes from a trusted source.
      # ``with`` closes the file even when unpickling fails.
      with open('./chinese_labels', 'rb') as f:
          return pickle.load(f)
  93.  
  # --- Model setup -----------------------------------------------------------
  # Build the network for 40x40 single-channel inputs and 3755 output classes,
  # then load the trained weights.
  img_rows, img_cols = 40, 40
  img_channels = 1
  num_classes = 3755
  img_input = Input(shape=(img_rows, img_cols, img_channels))
  output = densenet(img_input, num_classes)
  model = Model(img_input, output)
  model.load_weights('model-ep008-loss0.615-val_loss0.633.h5')
  sgd = optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)
  model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
  label_dict = get_label_dict()

  # --- Input images ----------------------------------------------------------
  path = '/Users/haodexian/PycharmProjects/scrapy-async-selenium-develop/reviews/img/'
  # Sorted so the walk order is reproducible; dotfiles (e.g. .DS_Store) are
  # skipped because they are not images.
  img_generator = (
      os.path.join(path, file)
      for file in sorted(os.listdir(path))
      if not file.startswith('.')
  )

  # --- Interactive loop: "n" processes the next image, "Q" quits -------------
  while True:
      text = input('input "n" ( input Q to exit ): ')
      if text == 'Q':
          break
      if text != 'n':
          # Anything else is ignored; ask again.
          continue

      img_path = next(img_generator, None)
      if img_path is None:
          # Directory exhausted: stop instead of dying with StopIteration.
          print("no more images")
          break
      print(img_path)
      if not os.path.exists(img_path):
          print("file not exist!")
          continue

      # Errors from here on propagate unchanged and end the script.
      images, ori = count_paint(img_path)
      plt.imshow(ori[:, 0:150], cmap='gray')
      plt.show()
      for glyph in images:
          img = cv2.resize(glyph, (40, 40))
          # Shape is (samples, height, width, channels); scale pixels to 0..1.
          img = (img.reshape(1, 40, 40, 1)).astype('int32') / 255
          results = model.predict(img)[0]
          # The 20 highest-scoring classes, best first.
          top_indices = results.argsort()[-20:][::-1]
          # Keep only candidates scoring at least 1 %.
          result = [
              (label_dict[idx], int(results[idx] * 100))
              for idx in top_indices
              if results[idx] * 100 >= 1
          ]
          print(result)
Add Comment
Please, Sign In to add comment