Advertisement
Guest User

Untitled

a guest
May 22nd, 2018
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.76 KB | None | 0 0
  1. import os
  2. import numpy as np
  3.  
  4. from random import shuffle
  5.  
  6. from sys import stdout
  7. import sys
  8.  
  9.  
  10. '''Подготовим обучающую выборку:'''
  11. '''*****************************************************'''
  12. #взять произвольное изображение
  13. def get_random_image(img_groups, group_names, gid):
  14.     gname = group_names[gid]
  15.     photos = img_groups[gname]
  16.     pid = np.random.choice(np.arange(len(photos)), size=1)[0]
  17.     pname = photos[pid]
  18.     return gname + pname + ".jpg"
  19.  
  20. #взять начальное изображение
  21. def get_first_image(img_groups, group_names, gid):
  22.     gname = group_names[gid]
  23.     pname = '000'
  24.     return gname + pname + ".jpg"
  25.  
  26. '''Генерации данных:'''
  27. def create_triples(rezSize, images, pointPP):
  28.     ''' Функция перебором создаёт тройки для валидации после обучения
  29.        rezSize - параметр устанавливает, сколько данных нужно в итоге получить
  30.        images - массив номеров изображений, которые будем использовать для валидации
  31.        
  32.        Возвращает список данных для валидации сети.
  33.    '''
  34.    
  35.     #распределяем изображения по группам (отдельная группа для отдельной области)
  36.     img_groups = {}
  37.    
  38.     print('Произвольная генерация данных:')
  39.     i_lim = 0
  40.     limit = rezSize//2
  41.     '''сделаем постепенную загрузку картинок (ведь, мы знаем сколько их)'''
  42.     #для одной картинки вырезано 400 точек по 79 преобразований на каждую:
  43.     #нумерация точек: от 0 до 399(pointsPerPict)
  44.     #нумерация картинок: от 1 до 16(imageCount)
  45.     #нумерация преобразований: от 0 до 78
  46.     pointsPerPict = pointPP #не больше 400!
  47.     trans = 79
  48.    
  49.     #цикл по картинкам:
  50.     for i in images:
  51.         curr_pict = '0'*(3 - len(str(i))) + str(i)
  52.         #цикл по точкам:
  53.         for j in range(0, pointsPerPict, 1):
  54.             #группа:
  55.             aid = str(j) + curr_pict
  56.             img_groups[aid] = []
  57.             #цикл по реобразованиям точек:
  58.             for k in range(0, trans):
  59.                 img_groups[aid].append('0'*(3 - len(str(k))) + str(k))
  60.                 i_lim += 1
  61.                 if i_lim >= limit:
  62.                     break
  63.            
  64.             if i_lim >= limit:
  65.                     break
  66.        
  67.         if i_lim >= limit:
  68.             break
  69.    
  70.     #создаем хорошие и плохие тройки (изображения одной области и разных соотв.)
  71.     pos_triples, neg_triples = [], []
  72.    
  73.     #позитивные пары - это комбинация изображений из одной группы
  74.     print('Хорошие тройки:')
  75.     for key in img_groups.keys():
  76.         print('Область %s' % (key))
  77.         triples = []
  78.         for i in range(len(img_groups[key])):
  79.             #triples.append((key + '0'*(3 - len(str(i))) + str(i) + ".jpg", key + "000" + ".jpg", 1))  
  80.             triples.append((key + img_groups[key][i] + '.jpg', key + '000' + '.jpg', 1))  
  81.             stdout.write("\rтройка # %d / %d" % (i, len(img_groups[key])-1))
  82.             stdout.flush()
  83.         pos_triples.extend(triples)
  84.         print("")
  85.        
  86.     #нужно такое же число плохих:
  87.     print('Плохие тройки:')
  88.     group_names = list(img_groups.keys())
  89.     for i in range(len(pos_triples)):
  90.         #подбор плохих пар только из точек одного изображения
  91.         #делается для того, чтобы сеть лучше различала похожие точки
  92.         g1, g2 = '',''
  93.         while True:
  94.             g1, g2 = np.random.choice(np.arange(len(group_names)), size=2, replace=False)
  95.             #если номера картинок совпадают, а номера точек - нет и одно из преобразований - 0
  96.             if group_names[g1][-3:] == group_names[g2][-3:] and \
  97.                group_names[g1][:-3] != group_names[g2][:-3]:
  98.                 break
  99.            
  100.         left =  get_first_image(img_groups, group_names, g1)
  101.         right = get_random_image(img_groups, group_names, g2)
  102.         neg_triples.append((left, right, 0))
  103.         stdout.write("\rтройка # %d / %d" % (i, len(pos_triples)-1))
  104.         stdout.flush()
  105.     print("")
  106.    
  107.     #добавляем плохие к хорошим:
  108.     pos_triples.extend(neg_triples)
  109.     #перемешиваем:
  110.     shuffle(pos_triples)
  111.     return pos_triples
  112.  
  113. def writeToFile(path, source_filename, data):
  114.     ''' Записывает data в текстовый файл:
  115.        элемент из data:
  116.            ('область0.jpg', 'область1.jpg', 1/0)
  117.        Пример:
  118.            Элемент из data:
  119.            ('0001070.jpg', '0001071.jpg', 1)
  120.    '''
  121.    
  122.     with open(os.path.join(path, source_filename), "w") as text_file:
  123.         for x in data:
  124.             text_file.write('%s %s %d' % (x[0], x[1], x[2]))
  125.             text_file.write('\n')
  126.            
  127. '''Основная программа: '''
  128. size=50000
  129. images=[x for x in range(1,2,1)]
  130. pointsPerPicture=400
  131.  
  132. triples = create_triples(size, images, pointsPerPicture)
  133.  
  134. print(len(triples))
  135. #запись в файл:
  136. writeToFile(os.getcwd(), 'setHM0.txt', triples)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement