Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import numpy as np
- from random import shuffle
- from sys import stdout
- import sys
- '''Подготовим обучающую выборку:'''
- '''*****************************************************'''
- #взять произвольное изображение
- def get_random_image(img_groups, group_names, gid):
- gname = group_names[gid]
- photos = img_groups[gname]
- pid = np.random.choice(np.arange(len(photos)), size=1)[0]
- pname = photos[pid]
- return gname + pname + ".jpg"
- #взять начальное изображение
- def get_first_image(img_groups, group_names, gid):
- gname = group_names[gid]
- pname = '000'
- return gname + pname + ".jpg"
- '''Генерации данных:'''
- def create_triples(rezSize, images, pointPP):
- ''' Функция перебором создаёт тройки для валидации после обучения
- rezSize - параметр устанавливает, сколько данных нужно в итоге получить
- images - массив номеров изображений, которые будем использовать для валидации
- Возвращает список данных для валидации сети.
- '''
- #распределяем изображения по группам (отдельная группа для отдельной области)
- img_groups = {}
- print('Произвольная генерация данных:')
- i_lim = 0
- limit = rezSize//2
- '''сделаем постепенную загрузку картинок (ведь, мы знаем сколько их)'''
- #для одной картинки вырезано 400 точек по 79 преобразований на каждую:
- #нумерация точек: от 0 до 399(pointsPerPict)
- #нумерация картинок: от 1 до 16(imageCount)
- #нумерация преобразований: от 0 до 78
- pointsPerPict = pointPP #не больше 400!
- trans = 79
- #цикл по картинкам:
- for i in images:
- curr_pict = '0'*(3 - len(str(i))) + str(i)
- #цикл по точкам:
- for j in range(0, pointsPerPict, 1):
- #группа:
- aid = str(j) + curr_pict
- img_groups[aid] = []
- #цикл по реобразованиям точек:
- for k in range(0, trans):
- img_groups[aid].append('0'*(3 - len(str(k))) + str(k))
- i_lim += 1
- if i_lim >= limit:
- break
- if i_lim >= limit:
- break
- if i_lim >= limit:
- break
- #создаем хорошие и плохие тройки (изображения одной области и разных соотв.)
- pos_triples, neg_triples = [], []
- #позитивные пары - это комбинация изображений из одной группы
- print('Хорошие тройки:')
- for key in img_groups.keys():
- print('Область %s' % (key))
- triples = []
- for i in range(len(img_groups[key])):
- #triples.append((key + '0'*(3 - len(str(i))) + str(i) + ".jpg", key + "000" + ".jpg", 1))
- triples.append((key + img_groups[key][i] + '.jpg', key + '000' + '.jpg', 1))
- stdout.write("\rтройка # %d / %d" % (i, len(img_groups[key])-1))
- stdout.flush()
- pos_triples.extend(triples)
- print("")
- #нужно такое же число плохих:
- print('Плохие тройки:')
- group_names = list(img_groups.keys())
- for i in range(len(pos_triples)):
- #подбор плохих пар только из точек одного изображения
- #делается для того, чтобы сеть лучше различала похожие точки
- g1, g2 = '',''
- while True:
- g1, g2 = np.random.choice(np.arange(len(group_names)), size=2, replace=False)
- #если номера картинок совпадают, а номера точек - нет и одно из преобразований - 0
- if group_names[g1][-3:] == group_names[g2][-3:] and \
- group_names[g1][:-3] != group_names[g2][:-3]:
- break
- left = get_first_image(img_groups, group_names, g1)
- right = get_random_image(img_groups, group_names, g2)
- neg_triples.append((left, right, 0))
- stdout.write("\rтройка # %d / %d" % (i, len(pos_triples)-1))
- stdout.flush()
- print("")
- #добавляем плохие к хорошим:
- pos_triples.extend(neg_triples)
- #перемешиваем:
- shuffle(pos_triples)
- return pos_triples
- def writeToFile(path, source_filename, data):
- ''' Записывает data в текстовый файл:
- элемент из data:
- ('область0.jpg', 'область1.jpg', 1/0)
- Пример:
- Элемент из data:
- ('0001070.jpg', '0001071.jpg', 1)
- '''
- with open(os.path.join(path, source_filename), "w") as text_file:
- for x in data:
- text_file.write('%s %s %d' % (x[0], x[1], x[2]))
- text_file.write('\n')
- '''Основная программа: '''
- size=50000
- images=[x for x in range(1,2,1)]
- pointsPerPicture=400
- triples = create_triples(size, images, pointsPerPicture)
- print(len(triples))
- #запись в файл:
- writeToFile(os.getcwd(), 'setHM0.txt', triples)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement