Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import argparse
- import collections
- import itertools
- import os
- import sys
- import time
- try:
- import cPickle as pickle
- except ImportError:
- import pickle
- import PIL.Image
- import numpy as np
- import vlfeat
- import scipy.cluster.vq
- import scipy.spatial.distance
- import matplotlib.pyplot as plt
# Input locations, relative to the current working directory.
PAGES_PATH = os.path.join('data', 'pages')             # page images: <page>.png
GT_PATH = os.path.join('data', 'GT')                   # ground truth: <page>.gtp
CODE_BOOK_PATH = os.path.join('data', 'codebook.bin')  # raw float32 code book

# Output directories for saved query/candidate crops.
MATCH_IMAGES_PATH = 'match_images'
IFS_MATCH_IMAGES_PATH = 'ifs_match_images'  # currently not referenced anywhere

# Allowed values for --spatial-pyramid-type (G = all labels, L = left half,
# R = right half of a word's labels).
SPATIAL_PYRAMID_TYPES = 'L R G GL GR LR GLR'.split()
# Allowed values for --cell-margin: which sides of a word box get shrunk.
CELL_MARGIN_TYPES = 'none horizontal vertical both'.split()
# Command-line interface. Note: pre_main() forces --use-ifs and
# --use-accumulator on and overwrites --accumulator-percentile while it sweeps
# the percentile from 0 to 100 in steps of 5.
argument_parser = argparse.ArgumentParser(
    description='Evaluate dense-SIFT bag-of-visual-words word retrieval '
                '(mean average precision) on the page corpus.')
argument_parser.add_argument(
    '--step-size', '-s', type=int, default=15,
    help='dense SIFT sampling step passed to vl_dsift (default: %(default)s)')
argument_parser.add_argument(
    '--cell-size', '-c', type=int, default=3,
    help='dense SIFT size passed to vl_dsift; the word-box margin is twice '
         'this value (default: %(default)s)')
argument_parser.add_argument(
    '--centroids', '-C', type=int, default=40,
    help='number of visual words; 4096 uses the precomputed code book '
         'instead of k-means (default: %(default)s)')
argument_parser.add_argument(
    '--k-means-iterations', '-k', type=int, default=20,
    help='k-means iterations, unused when --centroids is 4096 '
         '(default: %(default)s)')
argument_parser.add_argument(
    '--distance-metric', choices=['cityblock', 'cosine', 'euclidean'], default='cosine',
    help='scipy cdist metric used to rank candidates (default: %(default)s)')
argument_parser.add_argument(
    '--spatial-pyramid-type', '-S', choices=SPATIAL_PYRAMID_TYPES, default='LR',
    help='histograms to concatenate: G = all labels, L = left half, '
         'R = right half (default: %(default)s)')
argument_parser.add_argument(
    '--pages', '-p', type=int, default=1,
    help='number of pages, in sorted file-name order, to load '
         '(default: %(default)s)')
argument_parser.add_argument(
    '--accumulator-percentile', '-a', type=float, default=95.0,
    help='percentile of distinct shared-bin counts a candidate must reach; '
         'overwritten by the sweep in pre_main (default: %(default)s)')
argument_parser.add_argument(
    '--use-ifs', '-I', action='store_true',
    help='pre-filter candidates with the inverted file structure '
         '(forced on by the sweep in pre_main)')
argument_parser.add_argument(
    '--use-accumulator', '-A', action='store_true',
    help='additionally threshold IFS candidates by shared-bin count '
         '(forced on by the sweep in pre_main)')
argument_parser.add_argument(
    '--save-images', '-sa', action='store_true',
    help='save crops of every query and its ranked candidates under match_images/')
argument_parser.add_argument(
    '--verbose', action='store_true',
    help='report skipped or unanswerable queries on stderr')
argument_parser.add_argument(
    '--cell-margin', choices=CELL_MARGIN_TYPES, default='horizontal',
    help='which sides of each word box are shrunk by the cell margin '
         '(default: %(default)s)')
# Histograms of one word occurrence in concatenation order: all labels,
# left half, right half. (`global_` avoids clashing with the keyword.)
SpatialPyramid = collections.namedtuple('SpatialPyramid', 'global_ left right')
def makedirs(name, mode=0o777, exist_ok=False):
    """Recursively create the directory *name*, like ``os.makedirs``.

    Backport of the Python 3 ``exist_ok`` flag.

    :param name: directory path to create, including missing parents.
    :param mode: permission bits passed to ``os.makedirs``.
    :param exist_ok: if true, an already existing *directory* is not an error.
    :raises OSError: if creation fails, or if *name* already exists and
        either ``exist_ok`` is false or *name* is not a directory.
    """
    try:
        os.makedirs(name, mode)
    except OSError:
        # Cannot rely on checking for EEXIST, since the operating system
        # could give priority to other errors like EACCES or EROFS.
        if not exist_ok or not os.path.isdir(name):
            raise
def load_gtp_file(path):
    """Parse a ground-truth (``.gtp``) file into a word -> boxes mapping.

    Every non-blank line is expected to read ``x1 y1 x2 y2 word``; the four
    coordinates are parsed as integers and everything after them is the word.

    :param path: path of the ``.gtp`` file.
    :returns: ``collections.defaultdict(list)`` mapping each word to its
        ``(x1, y1, x2, y2)`` tuples in file order.
    :raises ValueError: if a non-blank line has fewer than five fields or a
        non-integer coordinate.
    """
    entries = collections.defaultdict(list)
    with open(path) as gtp_file:
        for raw_line in gtp_file:
            line = raw_line.strip()
            if not line:
                continue
            # Split at most four times so a transcription that contains
            # whitespace stays one word instead of raising ValueError.
            x1, y1, x2, y2, word = line.split(None, 4)
            entries[word].append((int(x1), int(y1), int(x2), int(y2)))
    return entries
def load_codebook(path, shape=(4096, 128)):
    """Load a raw float32 code book from *path*.

    :param path: binary file holding the code book as packed float32 values
        in native byte order (the format ``numpy.fromfile`` reads with
        ``dtype='float32'``).
    :param shape: shape of the returned array; defaults to 4096 x 128.
    :returns: float32 ``numpy.ndarray`` of the given shape.
    :raises ValueError: if the file does not hold exactly ``prod(shape)``
        values.
    """
    # Binary mode so no newline translation can touch the data, and a
    # context manager so the handle is always closed.
    with open(path, 'rb') as input_file:
        code_book = np.fromfile(input_file, dtype='float32')
    return np.reshape(code_book, shape)
def make_spatial_pyramid(data, length, type_='GLR'):
    """Build a concatenated label histogram ("spatial pyramid") for one word.

    *data* is split into a first half and a second half; for an odd number
    of labels the middle one belongs to neither half. The result is always
    laid out as ``[all labels | first half | second half]``. A part whose
    letter ('G', 'L' or 'R') is absent from *type_* is an all-zero histogram.

    :param data: sequence of non-negative integer labels.
    :param length: minimum number of bins per histogram (``np.bincount``
        ``minlength``).
    :param type_: one of 'L', 'R', 'G', 'GL', 'GR', 'LR', 'GLR'.
    :returns: 1-D integer array with ``3 * length`` bins (more if a label
        is >= *length*).
    :raises ValueError: for an unknown *type_*.
    """
    if type_ not in ('L', 'R', 'G', 'GL', 'GR', 'LR', 'GLR'):
        raise ValueError('unknown spatial pyramid type: %r' % type_)
    count = len(data)
    # Integer floor/ceil of count / 2. Under Python 2, `count / 2` is already
    # floored, so np.ceil cannot round up there.
    left_index = count // 2
    right_index = (count + 1) // 2
    empty = np.zeros(0, dtype=np.intp)
    parts = (
        data if 'G' in type_ else empty,
        data[:left_index] if 'L' in type_ else empty,
        data[right_index:] if 'R' in type_ else empty,
    )
    return np.concatenate([np.bincount(part, minlength=length) for part in parts])
def load_corpus(page_names):
    """Load the given pages and stitch them into a single corpus.

    For every page ``<name>``, ``PAGES_PATH/<name>.png`` is opened and
    ``GT_PATH/<name>.gtp`` is parsed with ``load_gtp_file``. Pages are placed
    side by side in the given order. All ground-truth boxes are shifted
    horizontally into the coordinates of the combined image.

    :param page_names: non-empty sequence of page names (file stems).
    :returns: nested auto-vivifying ``defaultdict`` with keys:
        ``'pages'`` (per page: 'offset', 'image_path', 'image', 'gtp_path',
        'gtp'), ``'gtp'`` (word -> shifted boxes), ``'image'`` (combined
        PIL image) and ``'data'`` (the combined image as a float32 array).
    :raises ValueError: if *page_names* is empty.
    """
    if not page_names:
        raise ValueError('load_corpus() needs at least one page name')

    def tree():
        # Nested defaultdict so corpus['pages'][name][key] needs no setup.
        return collections.defaultdict(tree)

    corpus = tree()
    images = []
    corpus_gtp = collections.defaultdict(list)
    offset = 0
    for page_name in page_names:
        page = corpus['pages'][page_name]
        page['offset'] = offset
        # Load page image
        image_path = os.path.join(PAGES_PATH, '%s.png' % page_name)
        page['image_path'] = image_path
        image = PIL.Image.open(image_path)
        page['image'] = image
        images.append(image)
        # Load page ground truth
        gtp_path = os.path.join(GT_PATH, '%s.gtp' % page_name)
        page['gtp_path'] = gtp_path
        gtp = load_gtp_file(gtp_path)
        page['gtp'] = gtp
        # Shift this page's boxes right by the width of all preceding pages
        for word, boxes in gtp.items():
            corpus_gtp[word].extend(
                (x1 + offset, y1, x2 + offset, y2) for x1, y1, x2, y2 in boxes)
        offset += image.width

    # Concatenate pages horizontally. `offset` now equals the total width.
    # Rows below a shorter page keep PIL.Image.new's default fill.
    max_height = max(image.height for image in images)
    corpus_image = PIL.Image.new(images[0].mode, (offset, max_height))
    x_offset = 0
    for image in images:
        corpus_image.paste(image, (x_offset, 0))
        x_offset += image.width

    corpus['gtp'] = corpus_gtp
    corpus['image'] = corpus_image
    corpus['data'] = np.array(corpus_image, dtype='float32')
    return corpus
def _plot_sweep(results, ylabel, ylim=None):
    """Show one point per swept accumulator percentile (*results*: percentile -> value)."""
    positions = range(len(results))
    # list(...) so matplotlib gets a sequence on Python 3 as well (dict views
    # are not array-like).
    plt.plot(positions, list(results.values()), 'o')
    plt.xlabel('Accumulator Percentile')
    plt.ylabel(ylabel)
    plt.xticks(positions, list(results.keys()))
    plt.grid(True)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.tight_layout()
    plt.show()


def pre_main(arguments):
    """Sweep the accumulator percentile and plot retrieval quality and runtime.

    Loads the first ``arguments.pages`` page images (sorted by file name)
    once. It then runs ``main`` with IFS and accumulator filtering enabled for
    every percentile 0, 5, ..., 100 and shows two plots: mean average
    precision and whole-second wall-clock runtime per percentile.

    ``arguments`` is mutated: ``use_ifs``, ``use_accumulator`` and
    ``accumulator_percentile`` are overwritten.
    """
    # The code book is not preloaded here: main() loads it itself when
    # --centroids is 4096, and it is not needed otherwise.
    page_names = [
        os.path.splitext(filename)[0]
        for filename in sorted(os.listdir(PAGES_PATH))[:arguments.pages]]
    corpus = load_corpus(page_names)

    mean_average_precisions = collections.OrderedDict()
    durations = collections.OrderedDict()
    # The sweep always evaluates IFS + accumulator filtering.
    arguments.use_ifs = True
    arguments.use_accumulator = True
    for accumulator_percentile in range(0, 105, 5):
        print(accumulator_percentile)
        arguments.accumulator_percentile = accumulator_percentile
        start = time.time()
        mean_average_precisions[accumulator_percentile] = main(arguments, corpus)
        durations[accumulator_percentile] = int(time.time() - start)

    _plot_sweep(mean_average_precisions, 'Mean Average Precision', ylim=(0, 1))
    _plot_sweep(durations, 'Runtime')
def _word_mask(frames, box, margin_type, cell_margin):
    """Return a boolean mask selecting the frames that lie inside a word box.

    :param frames: frame positions as returned by ``vl_dsift``.
    :param box: ``(x1, y1, x2, y2)`` word bounding box.
    :param margin_type: 'none', 'horizontal', 'vertical' or 'both'. Selects
        which sides of the box are shrunk by *cell_margin*.
    :param cell_margin: amount to shrink the box by, in box coordinates.
    :raises RuntimeError: for an unknown *margin_type*.
    """
    margins = {
        'none': (0, 0),
        'horizontal': (cell_margin, 0),
        'vertical': (0, cell_margin),
        'both': (cell_margin, cell_margin),
    }
    if margin_type not in margins:
        raise RuntimeError('unknown cell margin type: %r' % (margin_type,))
    dx, dy = margins[margin_type]
    x1, y1, x2, y2 = box
    # NOTE(review): assumes frames is (N, 2) with x in column 0 and y in
    # column 1 -- confirm against the vlfeat wrapper in use.
    xs = frames[:, 0]
    ys = frames[:, 1]
    return (xs >= x1 + dx) & (ys >= y1 + dy) & (xs <= x2 - dx) & (ys <= y2 - dy)


def _average_precision(hits, appearances):
    """Return the average precision of a ranked 0/1 relevance list.

    :param hits: relevance flags in rank order (rank 1 first).
    :param appearances: total number of relevant items, including any the
        ranking missed. The precision sum is divided by this.
    """
    found = 0
    precision_sum = 0.0
    for rank, hit in enumerate(hits, start=1):
        if hit:
            found += 1
            precision_sum += found / float(rank)
    return precision_sum / float(appearances)


def main(arguments, corpus):
    """Run one retrieval evaluation over *corpus* and return its mean AP.

    Steps:
      1. Compute dense SIFT over the whole corpus.
      2. Keep the descriptors inside each ground-truth word box.
      3. Quantise them into visual words: k-means, or the precomputed code
         book when ``--centroids`` is 4096.
      4. Build one spatial pyramid per word occurrence.
      5. Use every occurrence with at least one other occurrence of the same
         word as a query. Candidates are optionally pre-filtered by the
         inverted file structure and an accumulator threshold, then ranked by
         ``--distance-metric`` distance.

    :param arguments: parsed command-line namespace (see ``argument_parser``).
    :param corpus: mapping as built by ``load_corpus``; 'data', 'image' and
        'gtp' are used.
    :returns: mean of the per-query average precisions (``np.mean``; nan if
        no query was evaluated). Mean recall and mAP are also printed.
    """
    def warn(message):
        # Diagnostics about skipped queries, only shown with --verbose.
        if arguments.verbose:
            sys.stderr.write(message + '\n')

    # Dense SIFT over the whole corpus; the float32 pixel array is scaled by
    # its maximum.
    data = corpus['data']
    frames, descriptors = vlfeat.vl_dsift(
        data / data.max(), step=arguments.step_size, size=arguments.cell_size,
        fast=True, float_descriptors=True)

    # Collect the descriptors inside every word box. The box is shrunk by
    # 2 * cell_size on the sides selected by --cell-margin.
    cell_margin = 2 * arguments.cell_size
    words_descriptors = []
    word_data_indices = collections.OrderedDict()  # key -> (first row, row count)
    word_coordinates = collections.OrderedDict()   # key -> (x1, y1, x2, y2)
    next_row = 0
    for word, boxes in corpus['gtp'].items():
        for variation, box in enumerate(boxes):
            mask = _word_mask(frames, box, arguments.cell_margin, cell_margin)
            word_descriptors = descriptors[mask]
            words_descriptors.append(word_descriptors)
            frame_count = word_descriptors.shape[0]
            key = word, variation
            word_data_indices[key] = next_row, frame_count
            word_coordinates[key] = tuple(box)
            next_row += frame_count
    words_descriptors = np.concatenate(words_descriptors)

    # Quantise descriptors into visual-word labels.
    if arguments.centroids == 4096:
        code_book = load_codebook(CODE_BOOK_PATH)
        labels, _ = scipy.cluster.vq.vq(words_descriptors, code_book)
    else:
        _, labels = scipy.cluster.vq.kmeans2(
            words_descriptors, arguments.centroids, iter=arguments.k_means_iterations, minit='points')

    # One spatial pyramid per occurrence. keys[i] is the (word, variation)
    # of pyramids[i].
    keys = list(word_data_indices.keys())
    pyramids = [
        make_spatial_pyramid(labels[start:start + length], arguments.centroids,
                             arguments.spatial_pyramid_type)
        for start, length in word_data_indices.values()]

    # Inverted file structure: histogram bin -> occurrences with a non-zero
    # count in that bin.
    ifs = [set() for _ in range(len(pyramids[0]))]
    for word_index, pyramid in enumerate(pyramids):
        for bin_index in np.flatnonzero(pyramid):
            ifs[bin_index].add(word_index)

    # word -> indices of all its occurrences (the relevant set of a query).
    word_variation_indices = collections.defaultdict(set)
    for word_index, (word, _) in enumerate(keys):
        word_variation_indices[word].add(word_index)

    word_coordinates_values = list(word_coordinates.values())
    average_precisions = []
    average_recalls = []
    for word_index, ((word, variation), query) in enumerate(zip(keys, pyramids)):
        # Number of *other* occurrences of the query word; nothing to find
        # if there are none.
        appearances = len(word_variation_indices[word]) - 1
        if not appearances:
            warn('No duplicate appearances for (%s, %d)!' % (word, variation))
            continue

        if arguments.use_ifs:
            # Occurrences sharing a non-zero bin with the query, listed once
            # per shared bin (the query itself is among them).
            ifs_candidate_indices = list(itertools.chain.from_iterable(
                ifs[bin_index] for bin_index in np.flatnonzero(query)))
            candidate_indices = set(ifs_candidate_indices)
            if not candidate_indices:
                warn('No candidates for (%s, %d) after IFS!' % (word, variation))
                average_precisions.append(0)
                average_recalls.append(0)
                continue
            if arguments.use_accumulator:
                # Keep candidates whose shared-bin count reaches the requested
                # percentile of the distinct counts. The position is clamped
                # to the valid range.
                accumulator = collections.Counter(ifs_candidate_indices)
                rankings = sorted(set(accumulator.values()))
                position = int(len(rankings) * arguments.accumulator_percentile / 100.0) - 1
                threshold = rankings[min(max(0, position), len(rankings) - 1)]
                candidate_indices = set(
                    index for index, count in accumulator.items() if count >= threshold)
        else:
            candidate_indices = set(range(len(pyramids)))
        candidate_indices.discard(word_index)
        if not candidate_indices:
            warn('No candidates for (%s, %d)' % (word, variation))
            average_precisions.append(0)
            average_recalls.append(0)
            continue

        # Rank candidates by distance to the query. Sorting the candidates and
        # using a stable sort makes ties resolve by occurrence index rather
        # than by set iteration order.
        candidates = sorted(candidate_indices)
        candidate_pyramids = np.array([pyramids[index] for index in candidates])
        distances = scipy.spatial.distance.cdist(
            query.reshape(1, -1), candidate_pyramids, metric=arguments.distance_metric)[0]
        ranked = [candidates[i] for i in np.argsort(distances, kind='mergesort')]
        relevant = word_variation_indices[word]
        hits = [1 if index in relevant else 0 for index in ranked]
        average_precisions.append(_average_precision(hits, appearances))
        average_recalls.append(sum(hits) / float(appearances))

        if arguments.save_images:
            match_images_path = os.path.join(MATCH_IMAGES_PATH, '%s_%d' % (word, variation))
            makedirs(match_images_path, exist_ok=True)
            corpus['image'].crop(word_coordinates_values[word_index]).save(
                os.path.join(match_images_path, '0_original.png'))
            for rank, candidate_index in enumerate(ranked, start=1):
                corpus['image'].crop(word_coordinates_values[candidate_index]).save(
                    os.path.join(match_images_path, 'candidate_%d.png' % rank))

    print('Mean Recall: %f' % (np.mean(average_recalls) * 100))
    mean_average_precision = np.mean(average_precisions)
    print('Mean Average Precision: %f' % (mean_average_precision * 100))
    return mean_average_precision
if __name__ == '__main__':
    # Parse the command line (argv without the program name) and run the
    # accumulator-percentile sweep.
    arguments = argument_parser.parse_args(sys.argv[1:])
    pre_main(arguments)
Add Comment
Please, Sign In to add comment