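# Segmentation-based word spotting on scanned document pages: dense SIFT
# descriptors are quantized into visual words, every annotated word image is
# described by a spatial-pyramid bag-of-words histogram, and each word image
# is queried against all others, optionally pre-filtered through an inverted
# file structure (IFS) and a vote accumulator. Retrieval quality is reported
# as mean recall and mean average precision.
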
import argparse
import collections
import itertools
import os
import sys
import time

try:
    import cPickle as pickle
except ImportError:
    import pickle

import PIL.Image
import numpy as np
import vlfeat
import scipy.cluster.vq
import scipy.spatial.distance
import matplotlib.pyplot as plt

PAGES_PATH = os.path.join('data', 'pages')
GT_PATH = os.path.join('data', 'GT')
IFS_MATCH_IMAGES_PATH = os.path.join('ifs_match_images')
MATCH_IMAGES_PATH = os.path.join('match_images')
CODE_BOOK_PATH = os.path.join('data', 'codebook.bin')
SPATIAL_PYRAMID_TYPES = ['L', 'R', 'G', 'GL', 'GR', 'LR', 'GLR']
CELL_MARGIN_TYPES = ['none', 'horizontal', 'vertical', 'both']

argument_parser = argparse.ArgumentParser()
argument_parser.add_argument('--step-size', '-s', type=int, default=15)
argument_parser.add_argument('--cell-size', '-c', type=int, default=3)
argument_parser.add_argument('--centroids', '-C', type=int, default=40)
argument_parser.add_argument('--k-means-iterations', '-k', type=int, default=20)
argument_parser.add_argument('--distance-metric', choices=['cityblock', 'cosine', 'euclidean'], default='cosine')
argument_parser.add_argument('--spatial-pyramid-type', '-S', choices=SPATIAL_PYRAMID_TYPES, default='LR')
argument_parser.add_argument('--pages', '-p', type=int, default=1)
argument_parser.add_argument('--accumulator-percentile', '-a', type=float, default=95.0)
argument_parser.add_argument('--use-ifs', '-I', action='store_true')
argument_parser.add_argument('--use-accumulator', '-A', action='store_true')
argument_parser.add_argument('--save-images', '-sa', action='store_true')
argument_parser.add_argument('--verbose', action='store_true')
argument_parser.add_argument('--cell-margin', choices=CELL_MARGIN_TYPES, default='horizontal')
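# Example invocation (the script file name here is hypothetical):
#   python word_spotting.py --pages 2 --centroids 40 --spatial-pyramid-type LR \
#       --use-ifs --use-accumulator --accumulator-percentile 95 --save-images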

SpatialPyramid = collections.namedtuple('SpatialPyramid', ['global_', 'left', 'right'])


def makedirs(name, mode=0777, exist_ok=False):
    if not exist_ok:
        return os.makedirs(name, mode)

    # Taken from Python 3
    try:
        os.makedirs(name, mode)
    except OSError:
        # Cannot rely on checking for EEXIST, since the operating system
        # could give priority to other errors like EACCES or EROFS
        if not exist_ok or not os.path.isdir(name):
            raise


def load_gtp_file(path):
    entries = collections.defaultdict(list)
    with open(path) as file:
        for line in (line for line in (line.strip() for line in file) if line):
            x1, y1, x2, y2, word = line.split()
            entries[word].append((int(x1), int(y1), int(x2), int(y2)))

    return entries


def load_codebook(path):
    # Read the raw binary codebook (4096 visual words x 128 SIFT dimensions)
    with open(path, 'rb') as input_file:
        code_book = np.fromfile(input_file, dtype='float32')
    code_book = np.reshape(code_book, (4096, 128))
    return code_book


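# Worked example of the encoding below: for labels [0, 1, 1, 2], length 3 and
# type 'LR', the histogram is the concatenation (global, left, right) of the
# bincounts of [], [0, 1] and [1, 2]: [0, 0, 0, 1, 1, 0, 0, 1, 1].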
def make_spatial_pyramid(data, length, type_='GLR'):
    count = len(data)
    left_index = int(np.floor(count / 2))
    right_index = int(np.ceil(count / 2))

    if type_ == 'L':
        data = [], data[:left_index], []
    elif type_ == 'R':
        data = [], [], data[right_index:]
    elif type_ == 'G':
        data = data, [], []
    elif type_ == 'GL':
        data = data, data[:left_index], []
    elif type_ == 'GR':
        data = data, [], data[right_index:]
    elif type_ == 'LR':
        data = [], data[:left_index], data[right_index:]
    elif type_ == 'GLR':
        data = data, data[:left_index], data[right_index:]
    else:
        raise ValueError('unknown spatial pyramid type: %r' % type_)

    spatial_pyramid = SpatialPyramid(*(np.bincount(datum, minlength=length) for datum in data))
    return np.concatenate(spatial_pyramid)


def load_corpus(page_names):
    defaultdict_factory = lambda: collections.defaultdict(defaultdict_factory)
    corpus = collections.defaultdict(defaultdict_factory)

    offset = 0
    images = []
    corpus_gtp = collections.defaultdict(list)
    for page_name in page_names:
        corpus['pages'][page_name]['offset'] = offset

        # Load page image
        image_path = os.path.join(PAGES_PATH, '%s.png' % page_name)
        corpus['pages'][page_name]['image_path'] = image_path
        image = PIL.Image.open(image_path)
        corpus['pages'][page_name]['image'] = image
        images.append(image)

        # Load page GTP
        gtp_path = os.path.join(GT_PATH, '%s.gtp' % page_name)
        corpus['pages'][page_name]['gtp_path'] = gtp_path
        gtp = load_gtp_file(gtp_path)
        corpus['pages'][page_name]['gtp'] = gtp

        # Update global corpus GTP with current offset
        for word, coordinates in gtp.items():
            for x1, y1, x2, y2 in coordinates:
                corpus_gtp[word].append((x1 + offset, y1, x2 + offset, y2))

        offset += image.width

    # Create corpus image by concatenating page images horizontally
    width = sum(image.width for image in images)
    max_height = max(image.height for image in images)
    corpus_image = PIL.Image.new(images[0].mode, (width, max_height))
    x_offset = 0
    for image in images:
        corpus_image.paste(image, (x_offset, 0))
        x_offset += image.width

    corpus['gtp'] = corpus_gtp
    corpus['image'] = corpus_image
    corpus['data'] = np.array(corpus_image, dtype='float32')
    return corpus


def pre_main(arguments):
    # Sweep the accumulator percentile from 0 to 100 in steps of 5 (with the
    # IFS and the accumulator enabled) and plot mean average precision and runtime
    load_codebook(os.path.join('data', 'codebook.bin'))
    page_names = [os.path.splitext(filename)[0] for filename in sorted(os.listdir(PAGES_PATH))[:arguments.pages]]
    corpus = load_corpus(page_names)

    results1 = collections.OrderedDict()
    results2 = collections.OrderedDict()
    for accumulator_percentile in range(0, 105, 5):
        print accumulator_percentile
        arguments.use_ifs = True
        arguments.use_accumulator = True
        arguments.accumulator_percentile = accumulator_percentile
        start = time.time()
        mean_average_precision = main(arguments, corpus)
        duration = int(time.time() - start)
        results1[accumulator_percentile] = mean_average_precision
        results2[accumulator_percentile] = duration

    plt.plot(range(len(results1)), results1.values(), 'o')
    plt.xlabel('Accumulator Percentile')
    plt.ylabel('Mean Average Precision')
    plt.xticks(range(len(results1)), results1.keys())
    plt.grid(True)
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.show()

    plt.plot(range(len(results2)), results2.values(), 'o')
    plt.xlabel('Accumulator Percentile')
    plt.ylabel('Runtime (seconds)')
    plt.xticks(range(len(results2)), results2.keys())
    plt.grid(True)
    plt.tight_layout()
    plt.show()


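# main() runs one retrieval experiment on the prepared corpus:
#   1. compute dense SIFT over the concatenated corpus image,
#   2. collect the descriptors that fall inside each annotated word box,
#   3. quantize them to visual words (precomputed codebook or k-means),
#   4. build a spatial-pyramid bag-of-words histogram per word occurrence,
#   5. optionally prune candidates with the IFS and the vote accumulator,
#   6. rank the remaining candidates by histogram distance and score the
#      ranking with average precision and recall.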
def main(arguments, corpus):
    # Calculate SIFT data for corpus
    frames, descriptors = vlfeat.vl_dsift(
        corpus['data'] / corpus['data'].max(), step=arguments.step_size, size=arguments.cell_size,
        fast=True, float_descriptors=True)

    # Find all frames and descriptors contained inside word boundaries (minus a cell margin of cell_size * 2)
    cell_margin = 2 * arguments.cell_size
    words_frames = []
    words_descriptors = []
    previous_frame_index = 0
    word_data_indices = collections.OrderedDict()
    word_coordinates = collections.OrderedDict()
    for word, coordinates in corpus['gtp'].items():
        # Filter word frames within word bounding box
        for variation, (x1, y1, x2, y2) in enumerate(coordinates):
            if arguments.cell_margin == 'none':
                mask = (
                    (frames[:, 0] >= x1) & (frames[:, 1] >= y1) &
                    (frames[:, 0] <= x2) & (frames[:, 1] <= y2))
            elif arguments.cell_margin == 'horizontal':
                mask = (
                    (frames[:, 0] >= x1 + cell_margin) & (frames[:, 1] >= y1) &
                    (frames[:, 0] <= x2 - cell_margin) & (frames[:, 1] <= y2))
            elif arguments.cell_margin == 'vertical':
                mask = (
                    (frames[:, 0] >= x1) & (frames[:, 1] >= y1 + cell_margin) &
                    (frames[:, 0] <= x2) & (frames[:, 1] <= y2 - cell_margin))
            elif arguments.cell_margin == 'both':
                mask = (
                    (frames[:, 0] >= x1 + cell_margin) & (frames[:, 1] >= y1 + cell_margin) &
                    (frames[:, 0] <= x2 - cell_margin) & (frames[:, 1] <= y2 - cell_margin))
            else:
                raise ValueError('unknown cell margin type: %r' % arguments.cell_margin)

            # Get matching frames/descriptors for the word
            word_frames = frames[mask]
            words_frames.append(word_frames)
            words_descriptors.append(descriptors[mask])

            # Count how many frames are contained inside the bounding box
            frame_count = word_frames.shape[0]

            # Note at which index and how many (following) frames/descriptors are part of a word
            key = word, variation
            word_data_indices[key] = previous_frame_index, frame_count
            word_coordinates[key] = x1, y1, x2, y2
            previous_frame_index += frame_count
    words_frames = np.concatenate(words_frames)
    words_descriptors = np.concatenate(words_descriptors)

    if arguments.centroids == 4096:
        code_book = load_codebook(CODE_BOOK_PATH)
        labels, _ = scipy.cluster.vq.vq(words_descriptors, code_book)
    else:
        # Calculate labels
        _, labels = scipy.cluster.vq.kmeans2(
            words_descriptors, arguments.centroids, iter=arguments.k_means_iterations, minit='points')

    # Word -> labels mapping
    # noinspection PyArgumentList
    word_labels = collections.OrderedDict(
        (key, labels[start:start + length]) for key, (start, length) in word_data_indices.items())

    # Create (word, variation) -> spatial pyramid mapping
    # noinspection PyArgumentList
    spatial_pyramids = collections.OrderedDict(
        (key, make_spatial_pyramid(labels, arguments.centroids, arguments.spatial_pyramid_type))
        for key, labels in word_labels.items())

    # Create IFS database: ifs[i] holds the indices of all word images whose
    # spatial pyramid has a non-zero count in bin i
    ifs_height = len(spatial_pyramids.values()[0])
    ifs = [set() for count in range(ifs_height)]
    for word_index, spatial_pyramid in enumerate(spatial_pyramids.values()):
        for index, count in enumerate(spatial_pyramid):
            if count:
                ifs[index].add(word_index)

    # Map each word to the indices of all of its variations
    word_variation_indices = collections.defaultdict(set)
    for word_index, (word, variation) in enumerate(spatial_pyramids.keys()):
        word_variation_indices[word].add(word_index)

    # Find query in IFS
    spatial_pyramids_values = spatial_pyramids.values()
    word_coordinates_values = word_coordinates.values()
    average_precisions = []
    average_recalls = []
    for word_index, ((word, variation), query) in enumerate(spatial_pyramids.items()):
        # Skip words with no findable duplicates in the IFS database
        appearances = len(word_variation_indices[word]) - 1
        if not appearances:
            if arguments.verbose:
                print >> sys.stderr, 'No duplicate appearances for (%s, %d)!' % (word, variation)
            continue

        if arguments.use_ifs:
            ifs_candidate_indices = list(itertools.chain(*(ifs[index] for index, count in enumerate(query) if count)))
            candidate_indices = set(ifs_candidate_indices)
            if not candidate_indices:
                if arguments.verbose:
                    print >> sys.stderr, 'No candidates for (%s, %d) after IFS!' % (word, variation)
                average_precisions.append(0)
                average_recalls.append(0)
                continue

            if arguments.use_accumulator:
                # noinspection PyArgumentList
                accumulator = collections.Counter(ifs_candidate_indices)
                # No candidates left after having applied the IFS
                if not accumulator:
                    if arguments.verbose:
                        print >> sys.stderr, 'No candidates for (%s, %d) after IFS + Accumulator!' % (word, variation)
                    average_precisions.append(0)
                    average_recalls.append(0)
                    continue

                # Keep only candidates whose vote count reaches the requested
                # percentile of the distinct vote counts
                most_common = accumulator.most_common()
                rankings = sorted(set(accumulator.values()))
                percentile_ranking = rankings[max(0, int(len(rankings) * arguments.accumulator_percentile / 100.0) - 1)]
                candidate_indices = set(
                    index for index, count in
                    list(itertools.takewhile(lambda item: item[1] >= percentile_ranking, most_common)))
        else:
            candidate_indices = set(range(len(spatial_pyramids)))

        candidate_indices -= {word_index}
        if not candidate_indices:
            if arguments.verbose:
                print >> sys.stderr, 'No candidates for (%s, %d)' % (word, variation)
            average_precisions.append(0)
            average_recalls.append(0)
            continue

        candidate_pyramids = np.array([spatial_pyramids_values[index] for index in candidate_indices])
        query = query.reshape((1, query.shape[0]))
        distances = scipy.spatial.distance.cdist(query, candidate_pyramids, metric=arguments.distance_metric)[0]

        # Translate index in distance array to index of candidate
        distances_indices = range(distances.shape[0])
        distance_index_to_candidate_index = {
            distance_index: candidate_index for distance_index, candidate_index in
            zip(distances_indices, candidate_indices)}
        distances_sorted_indices = np.argsort(distances)
        sorted_candidate_indices = [
            distance_index_to_candidate_index[distance_index] for distance_index in distances_sorted_indices]

        hits = [1 if index in word_variation_indices[word] else 0 for index in sorted_candidate_indices]
        true_positives = sum(hits)
        # Calculate accumulated hits at index
        hits_at_k = []
        current_hits = 0
        for hit in hits:
            if hit:
                current_hits += 1
            hits_at_k.append(current_hits)
        # Average precision: mean of precision@k over the ranks k that are hits,
        # normalized by the number of relevant items (appearances)
        average_precision = sum(
            (current_hits / float(index)) * hit for index, (hit, current_hits) in
            enumerate(zip(hits, hits_at_k), start=1)) / float(appearances)
        average_precisions.append(average_precision)
        average_recalls.append(true_positives / float(appearances))

        if arguments.save_images:
            match_images_path = os.path.join(MATCH_IMAGES_PATH, '%s_%d' % (word, variation))
            makedirs(match_images_path, exist_ok=True)
            coordinates = word_coordinates_values[word_index]
            corpus['image'].crop(coordinates).save(os.path.join(match_images_path, '0_original.png'))

            for rank, candidate_word_index in enumerate(sorted_candidate_indices, start=1):
                coordinates = word_coordinates_values[candidate_word_index]
                path = os.path.join(match_images_path, 'candidate_%d.png' % rank)
                corpus['image'].crop(coordinates).save(path)

        # print 'Word %s (Variation: %d): %.2f%%' % (word, variation, average_precision * 100)

    print 'Mean Recall: %f' % (np.mean(average_recalls) * 100)
    mean_average_precision = np.mean(average_precisions)
    print 'Mean Average Precision: %f' % (mean_average_precision * 100)
    return mean_average_precision


if __name__ == '__main__':
    arguments = argument_parser.parse_args()
    pre_main(arguments)