import cv2
import os
from Dataset_Utils.dataset_tools import findRelevantFace, enclosing_square
from Dataset_Utils.facedetect_vggface2.face_detector import FaceDetector
from Dataset_Utils.facedetect_vggface2.face_aligner import FaceAligner
import dlib
from tqdm import tqdm
import numpy as np
import joblib
from keras.models import model_from_json

from Dataset.Dataset_Utils.dataset_tools import findFaceOnSide, cut
from openface_model import create_model

BASE_PATH = os.path.dirname(os.path.abspath(__file__))
cache_p = 'Aff_Wild_Cache/'
input_p_ds = 'AFF-Wild/VA'


def findCosineDistance(source_representation, test_representation):
    # Cosine distance between two embedding vectors: 1 - cos(angle).
    a = np.matmul(np.transpose(source_representation), test_representation)
    b = np.sum(np.multiply(source_representation, source_representation))
    c = np.sum(np.multiply(test_representation, test_representation))
    return 1 - (a / (np.sqrt(b) * np.sqrt(c)))


def l2_normalize(x, axis=-1, epsilon=1e-10):
    # L2-normalize x along the given axis; epsilon guards against division by zero.
    output = x / np.sqrt(np.maximum(np.sum(np.square(x), axis=axis, keepdims=True), epsilon))
    return output


def findEuclideanDistance(source_representation, test_representation):
    euclidean_distance = source_representation - test_representation
    euclidean_distance = np.sum(np.multiply(euclidean_distance, euclidean_distance))
    euclidean_distance = np.sqrt(euclidean_distance)
    # euclidean_distance = l2_normalize(euclidean_distance)
    return euclidean_distance
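
# A minimal usage sketch of the helpers above (added for illustration; the
# 128-d embedding size and the 0.4 threshold are assumptions, not values
# taken from this script):
#
#   emb_a = l2_normalize(np.random.rand(128))
#   emb_b = l2_normalize(np.random.rand(128))
#   if findCosineDistance(emb_a, emb_b) < 0.4:
#       print("embeddings are close: likely the same face")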


class AffWild_Dataset:
    gen = None
    partition = None

    def __init__(self, partition='Training', input_path=input_p_ds, cache_path=cache_p, target_shape=(224, 224, 3),
                 shuffle_samples=True, augment=True, custom_augmentation=None, debug_max_num_samples=None, split=False,
                 split_len=16, cast_to_imgs=False, number_frames_for_cast=4):
        self.target_shape = target_shape
        self.custom_augmentation = custom_augmentation
        self.augment = augment
        self.split = split
        self.info = list()
        self.partition = partition
        import tensorflow as tf

        print("Loading OpenFace model")
        self.model = create_model()
        self.model.load_weights("Dataset_Utils/openface_weights.h5")
        print('Loading data...')

        cache_path = os.path.join(cache_path, partition)
        cache_file_name = '%s.%s.info' % ("aff_wild", partition)
        # If it exists, read the dataset info from the cache file.
        try:
            with open(os.path.join(cache_path, cache_file_name), 'rb') as f:
                self.info = joblib.load(f)
                if debug_max_num_samples is not None and debug_max_num_samples < len(self.info):
                    self.info = self.info[:debug_max_num_samples]

            print("Data loaded. %d samples, from cache" % (len(self.info)))

        except FileNotFoundError:
            print('File not found, creating...')
            # Read the dataset starting from the original videos:
            # create the cache folder if it does not exist yet, then
            # read and process every video into the cache.
            if not os.path.isdir(cache_path):
                os.makedirs(cache_path)

            self._create_map(partition, input_path)
            # Save the dataset info as a cache file.
            print("Doing backup on cache file")
            with open(os.path.join(cache_path, cache_file_name), 'wb') as f:
                joblib.dump(self.info, f)

    def _create_map(self, partition, input_path_ds):
        fd = FaceDetector()
        fa = FaceAligner()
        annotation_path = input_path_ds + '/annotation' + '/' + partition
        video_path = input_path_ds + "/videos/" + partition
        # Iterate over every annotation file (one per video) in the directory.
        for filename in tqdm(os.listdir(annotation_path)):
            print(filename)
            try:
                self._process_video(path=(annotation_path + "/" + filename), face_detector=fd, face_aligner=fa,
                                    time_step=16, video_dir_path=video_path)
            except Exception as e:
                # Log the failure and keep going with the next video.
                print(e)

    def _get_annotations(self, path):
        # Annotation files are named after their video; a '_right'/'_left'
        # suffix marks a video annotated separately for two subjects.
        video_name = path.split('/')[-1].replace('.txt', '')
        annotations = {'normal': [None, True], 'right': [None, False], 'left': [None, False]}

        if '_right' in video_name:
            video_name = video_name.replace('_right', '')
            annotations['normal'][1] = False
            annotations['right'][1] = True
        elif '_left' in video_name:
            video_name = video_name.replace('_left', '')
            annotations['normal'][1] = False
            annotations['left'][1] = True

        annotations_path = os.path.dirname(path)
        if os.path.isfile(annotations_path + "/" + video_name + '.txt'):
            annotations['normal'][0] = np.loadtxt(annotations_path + "/" + video_name + '.txt', delimiter=',',
                                                  skiprows=1)
        if os.path.isfile(annotations_path + "/" + video_name + '_right.txt'):
            annotations['right'][0] = np.loadtxt(annotations_path + "/" + video_name + '_right.txt', delimiter=',',
                                                 skiprows=1)
        if os.path.isfile(annotations_path + "/" + video_name + '_left.txt'):
            annotations['left'][0] = np.loadtxt(annotations_path + "/" + video_name + '_left.txt', delimiter=',',
                                                skiprows=1)

        return annotations
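
    # Shape of the map returned above for e.g. 'some_video_right.txt'
    # (a sketch; array values depend on which annotation files exist on disk):
    #   {'normal': [<ndarray or None>, False],
    #    'right':  [<ndarray>, True],
    #    'left':   [<ndarray or None>, False]}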

    def _process_video(self, path='', face_detector=None, face_aligner=None, time_step=16, video_dir_path=''):
        video_name = path.split('/')[-1].replace('.txt', '')
        mode = 'normal'

        annotations = self._get_annotations(path)
        if "_right" in video_name:
            video_name = video_name.replace('_right', '')
        elif "_left" in video_name:
            video_name = video_name.replace('_left', '')

        video_path = video_dir_path + "/" + video_name
        if os.path.isfile(video_path + '.mp4'):
            video_path = video_path + '.mp4'
        if os.path.isfile(video_path + '.avi'):
            video_path = video_path + '.avi'

        cv2video = cv2.VideoCapture(video_path)

        # Pick the annotation flagged as current; keep any other annotation
        # present in the same video (the other subject) as the extra one.
        extra_annotation = None
        ann = None
        for k, v in annotations.items():
            if v[1]:
                ann = [k, v[0]]
            elif v[0] is not None:
                extra_annotation = [k, v[0]]

        info = self._init_map(video_name, cv2video, time_step, ann[1])

        if not cv2video.isOpened():
            print("Error opening video stream or file")

        info['roi'], info['landmarks'], info['corrupted_frames'], info['indices'] = self._get_video_info(
            cv2video, face_detector, face_aligner, info, mode, time_step, ann, extra_annotation)

        self.info.append(info)

    def _init_map(self, video_name, cv2video, time_step, ann):
        print("Init info map")
        info = {
            'video_name': video_name,
            'total_frames': cv2video.get(cv2.CAP_PROP_FRAME_COUNT),
            'fps': cv2video.get(cv2.CAP_PROP_FPS),
            'height': cv2video.get(cv2.CAP_PROP_FRAME_HEIGHT),
            'width': cv2video.get(cv2.CAP_PROP_FRAME_WIDTH),
            'time_step': time_step,
            'indices': list(),
            'annotations': ann,
            'duration': cv2video.get(cv2.CAP_PROP_FRAME_COUNT) / cv2video.get(cv2.CAP_PROP_FPS),
            'landmarks': list(),
            'roi': list(),
            'corrupted_frames': list(),
        }
        print("End Init info map")
        return info

    def _get_video_info(self, cv2video, face_detector, face_aligner, info, mode, time_step, current_annotation,
                        extra_annotation):
        frame_counter = 0
        rois = list()
        landmarks = list()
        corrupted_frames = list()
        indices = list()
        first = True
        curr_bound = 0
        last_corr = 0
        print("l1: ", len(current_annotation[1]))
        if extra_annotation is not None:
            print("l2: ", len(extra_annotation[1]))
        print("tot:", info['total_frames'])

        # Read until the video is completed.
        print("Processing video")
        # pbar = tqdm(total=info['total_frames'] + 1)
        prec_face = None
        while frame_counter < info['total_frames']:
            # Capture frame by frame.
            ret, frame = cv2video.read()
            if ret:
                faces = face_detector.detect(frame)

                if extra_annotation is None:
                    f = findRelevantFace(faces, frame.shape[1], frame.shape[0])
                else:
                    if len(faces) > 1:
                        if current_annotation[0] in ('right', 'left') and extra_annotation[0] in ('right', 'left'):
                            # Both side annotations (right and left) are present.
                            if current_annotation[0] == 'right':
                                f = findFaceOnSide(faces, True, frame.shape[1])
                            else:
                                f = findFaceOnSide(faces, False, frame.shape[1])
                        elif current_annotation[0] == 'normal' or extra_annotation[0] == 'normal':
                            if current_annotation[0] == 'normal':
                                if extra_annotation[0] == 'right':
                                    f = findFaceOnSide(faces, False, frame.shape[1], True)
                                else:
                                    f = findFaceOnSide(faces, True, frame.shape[1], True)
                            else:
                                if current_annotation[0] == 'right':
                                    f = findFaceOnSide(faces, True, frame.shape[1])
                                else:
                                    f = findFaceOnSide(faces, False, frame.shape[1])
                    elif len(faces) == 1:
                        # A frame annotation of (0, 0) means "no valid label".
                        current_frame_annotation = not (current_annotation[1][frame_counter][0] == 0 and
                                                        current_annotation[1][frame_counter][1] == 0)
                        current_extra_frame_annotation = not (extra_annotation[1][frame_counter][0] == 0 and
                                                              extra_annotation[1][frame_counter][1] == 0)
                        # print(current_extra_frame_annotation)
                        if current_frame_annotation or not current_extra_frame_annotation:
                            f = findRelevantFace(faces, frame.shape[1], frame.shape[0])
                        else:
                            f = None
                    else:
                        f = None

                if (f is not None) and (f['img'].size != 0):
                    tmp_roi = enclosing_square(f['roi'])
                    tmp_img = cut(frame, tmp_roi)
                    tmp_img = cv2.resize(tmp_img, (96, 96))
                    if prec_face is None:
                        prec_face = tmp_img

                    if frame_counter > 0:
                        # Debug visualization of the current vs. previous face crop.
                        if current_annotation[0] == 'right':
                            cv2.imshow("fimg", f['img'])
                            cv2.imshow("tmp", tmp_img)
                            cv2.imshow("prec", prec_face)
                            cv2.waitKey(0)

                        # Compare the OpenFace embeddings of consecutive crops.
                        img1_representation = self.model.predict(np.expand_dims(prec_face, axis=0))[0, :]
                        img2_representation = self.model.predict(np.expand_dims(tmp_img, axis=0))[0, :]

                        cosine = findCosineDistance(img1_representation, img2_representation)
                        euclidean = findEuclideanDistance(img1_representation, img2_representation)

                        prec_face = tmp_img

                        print("cosine: ", cosine)
                        print("eucl: ", euclidean)

                    f['roi'] = enclosing_square(f['roi'])
                    # f['roi'] = add_margin(f['roi'], 0.2)
                    detections = dlib.full_object_detections()
                    detections.append(
                        face_aligner.get_shape_detections(frame, dlib.rectangle(f['roi'][0], f['roi'][1],
                                                                                f['roi'][0] + f['roi'][2],
                                                                                f['roi'][1] + f['roi'][3])))
                    rois.append(f['roi'])
                    landmarks.append(detections)

                    # cv2.imshow("left", f['img'])
                    # cv2.waitKey(0)
                else:
                    corrupted_frames.append(frame_counter)
                    rois.append(None)
                    landmarks.append(None)
            else:
                print("ret False")
                corrupted_frames.append(frame_counter)
                rois.append(None)
                landmarks.append(None)

            # Every time_step usable (non-corrupted) frames, close the current
            # window: record (curr_bound, frame_counter - 1) if it contains at
            # least time_step usable frames, then advance the bound.
            if frame_counter - len(corrupted_frames) > 0:
                if (frame_counter - len(corrupted_frames)) % time_step == 0:
                    if (frame_counter - curr_bound) >= len(corrupted_frames) - last_corr + time_step:
                        indices.append((curr_bound, frame_counter - 1))
                        last_corr = len(corrupted_frames)
                        curr_bound = frame_counter
                    else:
                        last_corr = len(corrupted_frames)
                        curr_bound = frame_counter

            frame_counter += 1
            # pbar.update(1)
        # pbar.close()

        # When everything is done, release the video capture object
        # and close all OpenCV windows.
        cv2video.release()
        cv2.destroyAllWindows()

        return (rois, landmarks, corrupted_frames, indices)


if __name__ == "__main__":
    AffWild_Dataset()
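
# A non-default instantiation (a sketch using only arguments defined above;
# the paths must point at a local copy of the AFF-Wild VA dataset):
#   AffWild_Dataset(partition='Training', input_path='AFF-Wild/VA',
#                   cache_path='Aff_Wild_Cache/', debug_max_num_samples=100)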