import os
import random
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy
from PIL import Image, ImageEnhance, ImageFilter
from sklearn.ensemble import RandomForestClassifier

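# NOTE (assumption): the hard-coded 43 classes, the zero-padded class folder
# names and the "<track>_<frame>.ppm" file names suggest a GTSRB-style
# traffic-sign dataset layout: one top-level directory per class, each holding
# .ppm frames grouped into tracks, so frames of one track can be kept together
# when splitting into training and testing sets.
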
def normalize_and_ravel(image):
    # Flatten the image into a 1-D vector and scale pixel values to [0, 1].
    I = numpy.array(image)
    I = numpy.ravel(I, order='C')
    return I / 255


def normalize_and_ravel_set(set_pictures):
    new_dict = dict()
    for a, b in set_pictures.items():
        new_dict[a] = list()
        for i in b:
            new_dict[a].append(normalize_and_ravel(i))
    return new_dict


def find_freq(data: Dict[str, List[str]]):
    # Count how many samples each class has; class keys are cast to int.
    freq = dict()
    for a, b in data.items():
        freq[int(a)] = len(b)
    return freq


def plot_freq(freq):
    y = list()
    x = list()
    for a, b in freq.items():
        x.append(int(a))
        y.append(b)
    plt.figure()
    plt.bar(x, y)


def return_top_level_dirs(path_to_dir: str) -> List[str]:
    """
    :param path_to_dir: path to the directory
    :return: names of the top-level subdirectories of path_to_dir
    """
    return next(os.walk(path_to_dir))[1]


def stack_vectors_prepare_labels(training_dict):
    labels = []
    data = []
    for a, b in training_dict.items():
        for i in b:
            data.append(i)
            labels.append(a)
    return numpy.array(data), numpy.array(labels)


def split_vids(data: List[List[str]]) -> Tuple[List[List[str]], List[List[str]]]:
    # Shuffle the tracks and split them 20/80 into (testing, training).
    data_1 = list()
    data_2 = list()
    random.shuffle(data)
    for elem in range(len(data)):
        if elem >= int(len(data) * 0.8):
            data_2.append(data[elem])
        else:
            data_1.append(data[elem])
    return data_2, data_1


def change_picture(image):
    # Randomly perturb brightness, contrast and blur to create an augmented copy.
    brightness_param = random.uniform(0.5, 2)
    contrast_param = random.uniform(0.5, 2)
    blur_param = random.uniform(0, 1.25)
    brightness = ImageEnhance.Brightness(image)
    image = brightness.enhance(brightness_param)
    contrast = ImageEnhance.Contrast(image)
    image = contrast.enhance(contrast_param)

    image = image.filter(ImageFilter.GaussianBlur(radius=blur_param))

    return image


def augment(training_set_pictures, freq, support_pick=None):
    # Balance the classes: augment every class up to the size of the largest one.
    max_count = 0
    for a, b in freq.items():
        if b > max_count:
            max_count = b
    for a, b in freq.items():
        length = len(training_set_pictures[a])
        for i in range(0, max_count - b):
            if not support_pick:
                random_picture_index = random.randint(0, length - 1)
                random_picture = training_set_pictures[a][random_picture_index]
            else:
                random_picture = support_pick(a)
            training_set_pictures[a].append(change_picture(random_picture))
    if support_pick:
        # Classes missing from this batch are filled with pictures chosen by support_pick.
        for class_ in range(0, 43):
            if class_ not in training_set_pictures:
                training_set_pictures[class_] = []
                for i in range(max_count):
                    random_picture = support_pick(class_)
                    training_set_pictures[class_].append(random_picture)
    return training_set_pictures


def format_set(image_set, size):
    # Convert every picture in the set to a square image of the given size.
    formatted_set = dict()
    for a, b in image_set.items():
        formatted_set[a] = list()
        for i in b:
            formatted_set[a].append(change_to_format(i, size))
    return formatted_set


def return_files_in_dir(path_to_dir: str) -> List[str]:
    """
    :param path_to_dir: path to the directory
    :return: paths of all files stored (recursively) under path_to_dir
    """
    files = []
    # r = root, d = directories, f = files
    for r, d, f in os.walk(path_to_dir):
        for file in f:
            files.append(os.path.join(r, file))
    return files


def create_images_set(files_set):
    images_set = dict()
    for a, b in files_set.items():
        images_set[int(a)] = list()
        for path in b:
            image = Image.open(path)
            image.load()
            images_set[int(a)].append(image)
    return images_set


# Pad rectangular images to a square shape with zero (black) pixels, then resize.
def change_to_format(old_im, size):
    old_size = old_im.size
    width, height = old_size
    if width > height:
        desired_size = width
    else:
        desired_size = height
    new_size = (desired_size, desired_size)
    new_im = Image.new("RGB", new_size)

    new_im.paste(old_im, ((new_size[0] - old_size[0]),
                          (new_size[1] - old_size[1])))
    new_im = new_im.resize(size, Image.ANTIALIAS)
    return new_im


def get_num_of_vids_for_class(files_in_dict: List[str]) -> int:
    # Track ids are encoded in the file name as "<track>_<frame>.ppm".
    max_track = 0
    for j in files_in_dict:
        if j.split(".")[-1] == "ppm":
            track_id = (j.split("/")[-1]).split("_")[0]
            if int(track_id) > max_track:
                max_track = int(track_id)
    return max_track + 1


def get_classes(path_to_dir):
    return return_top_level_dirs(path_to_dir)


def get_images_of_class(class_name, path_to_dir):
    path_to_dir = path_to_dir + "/" + class_name
    files = return_files_in_dir(path_to_dir)
    images = []
    for file in files:
        if file.split(".")[-1] == "ppm":
            images.append(file)
    return images


def split_vids_within_one_class(cls, path_to_dir):
    # Group a class's images by track, then split whole tracks into training
    # and testing sets so frames of one track never end up in both.
    training_set_files = []
    testing_set = []
    cls_images = get_images_of_class(cls, path_to_dir)
    num_of_vids = get_num_of_vids_for_class(cls_images)
    vids_list: List[List[str]] = list()
    for _ in range(num_of_vids):
        vids_list.append(list())
    for image in cls_images:
        vid_id = int((image.split("/")[-1]).split("_")[0])
        vids_list[vid_id].append(image)
    testing_data, training_data = split_vids(vids_list)
    for vid in testing_data:
        for image in vid:
            testing_set.append(image)
    for vid in training_data:
        for image in vid:
            training_set_files.append(image)
    return training_set_files, testing_set


def get_track_files(file_set, track):
    # Select, per class, only the files belonging to the given track id.
    result = defaultdict(list)
    for class_, files in file_set.items():
        for path_str in files:
            path = Path(path_str)
            current_track = int(path.parts[-1].split("_")[0])
            if current_track == track:
                result[class_].append(path_str)
    return result


def prepare_batch_data(path_to_dir, size, augment_flag):
    """
    Like `prepare_data`, but produces a generator which yields batches of
    training data (one track per class at a time) to better fit in memory.
    """
    classes: List[str] = get_classes(path_to_dir)
    training_set_files: Dict[str, List[str]] = dict()
    testing_set: Dict[str, List[str]] = dict()
    for cls in classes:
        training_set_files[cls], testing_set[cls] = split_vids_within_one_class(cls, path_to_dir)

    def support_pick(class_):
        # Pick a random training file of the class (class directories are
        # zero-padded, e.g. 3 -> "00003") and return it formatted.
        class_str = str(class_).zfill(5)
        random_file = random.choice(training_set_files[class_str])
        picked = create_images_set({0: [random_file]})
        picked = format_set(picked, size)
        return picked[0][0]

    def training_set_generator():
        # Yield one batch per track id, taking one track from every class.
        track = 0
        while True:  # breaks when there are no more tracks left
            images = get_track_files(training_set_files, track)
            if not images:
                break

            images = create_images_set(images)
            images = format_set(images, size)
            freq = find_freq(images)
            if augment_flag:
                images = augment(images, freq, support_pick)
            images = normalize_and_ravel_set(images)
            data, labels = stack_vectors_prepare_labels(images)
            yield data, labels

            track += 1

    validation_set_pictures = create_images_set(testing_set)
    validation_set_pictures = format_set(validation_set_pictures, size)
    validation_set_pictures = normalize_and_ravel_set(validation_set_pictures)
    validation_data, validation_labels = stack_vectors_prepare_labels(validation_set_pictures)

    return training_set_generator, validation_data, validation_labels


def prepare_data(path_to_dir, size, augment_flag):
    classes: List[str] = get_classes(path_to_dir)
    training_set_files: Dict[str, List[str]] = dict()
    testing_set: Dict[str, List[str]] = dict()
    for cls in classes:
        training_set_files[cls], testing_set[cls] = split_vids_within_one_class(cls, path_to_dir)

    freq = find_freq(training_set_files)
    plot_freq(freq)
    training_set_pictures = create_images_set(training_set_files)
    training_set_pictures = format_set(training_set_pictures, size)
    if augment_flag:
        augment(training_set_pictures, freq)
        # Plot the class frequencies again after augmentation (they should now be balanced).
        freq = find_freq(training_set_pictures)
        plot_freq(freq)
    training_set_pictures = normalize_and_ravel_set(training_set_pictures)
    validation_set_pictures = create_images_set(testing_set)
    validation_set_pictures = format_set(validation_set_pictures, size)
    validation_set_pictures = normalize_and_ravel_set(validation_set_pictures)
    training_data, training_labels = stack_vectors_prepare_labels(training_set_pictures)
    validation_data, validation_labels = stack_vectors_prepare_labels(validation_set_pictures)
    return training_data, training_labels, validation_data, validation_labels


def main():
    sizes = [(5, 5), (15, 15), (30, 30), (50, 50), (100, 100)]
    for augment_flag in [True, False]:
        to_plot_acc = []
        to_plot_time = []
        for size in sizes:
            begin = time.monotonic()
            print(size, augment_flag)
            data_tr, labels_tr, data_val, labels_val = prepare_data("../Images/", size, augment_flag)
            clf = RandomForestClassifier(n_jobs=1, random_state=0)
            clf.fit(data_tr, labels_tr)
            guessed_data = clf.predict(data_val)
            print(labels_val)
            print(guessed_data)

            guesses_ok = 0
            for a, b in zip(guessed_data, labels_val):
                if a == b:
                    guesses_ok = guesses_ok + 1

            to_plot_acc.append(guesses_ok / len(guessed_data))
            total = time.monotonic() - begin
            to_plot_time.append(total)
        # Plot against the image side length (the images are square).
        side_lengths = [s[0] for s in sizes]
        plt.figure()
        plt.title(f"Time vs. Image size with augmented: {augment_flag}")
        plt.xlabel("size (px)")
        plt.ylabel("time (sec)")
        plt.plot(side_lengths, to_plot_time)
        plt.figure()
        plt.title(f"Accuracy vs. Image size with augmented: {augment_flag}")
        plt.xlabel("size (px)")
        plt.ylabel("accuracy")
        plt.plot(side_lengths, to_plot_acc)
        plt.show()


def batch_main():
    sizes = [(5, 5), (15, 15), (30, 30), (50, 50), (100, 100)]
    for augment_flag in [True, False]:
        to_plot_acc = []
        to_plot_time = []
        for size in sizes:
            begin = time.monotonic()
            print(size, augment_flag)
            training_gen, data_val, labels_val = prepare_batch_data("../Images/",
                                                                    size,
                                                                    augment_flag)

            # Grow the forest incrementally: warm_start keeps the already
            # fitted trees, and each batch adds 10 new estimators.
            clf = RandomForestClassifier(n_jobs=1, random_state=0,
                                         warm_start=True, n_estimators=0)

            n_estimators = 10
            for data_tr, labels_tr in training_gen():
                clf.set_params(n_estimators=n_estimators)
                clf.fit(data_tr, labels_tr)
                n_estimators += 10

            guessed_data = clf.predict(data_val)
            print(labels_val)
            print(guessed_data)

            guesses_ok = 0
            for a, b in zip(guessed_data, labels_val):
                if a == b:
                    guesses_ok = guesses_ok + 1

            to_plot_acc.append(guesses_ok / len(guessed_data))
            total = time.monotonic() - begin
            to_plot_time.append(total)
        # Plot against the image side length (the images are square).
        side_lengths = [s[0] for s in sizes]
        plt.figure()
        plt.title(f"Time vs. Image size with augmented: {augment_flag}")
        plt.xlabel("size (px)")
        plt.ylabel("time (sec)")
        plt.plot(side_lengths, to_plot_time)
        plt.figure()
        plt.title(f"Accuracy vs. Image size with augmented: {augment_flag}")
        plt.xlabel("size (px)")
        plt.ylabel("accuracy")
        plt.plot(side_lengths, to_plot_acc)
        plt.show()


if __name__ == "__main__":
    # main() runs the whole pipeline in one pass; batch_main() streams it in batches.
    batch_main()