Guest User

Untitled

a guest
Oct 18th, 2019
151
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. # -*- coding: utf-8 -*-
  2. # Hossam Amer
  3. # Run using this way: python3 visualize_featureMaps.py
  4.  
  5. # Inception image recognition attempt v1
  6. import tensorflow as tf
  7. import numpy as np
  8. import re
  9. import os
  10. import time
  11. from tkinter import *
  12. import tkinter.filedialog
  13. import matplotlib.pyplot as plt
  14. import logging
  15.  
  16. # Video capture and convert rgb
  17. from video_capture import VideoCaptureYUV
  18. import cv2
  19.  
  20. # Node look up
  21. from node_lookup import NodeLookup
  22.  
  23. import time
  24.  
  25. # for fetching files
  26. import glob
  27.  
  28. import math
  29.  
  30. from random import randrange
  31. import errno
  32.  
  33. ## CODE FROM LAOD DATA
  34. # path_jpeg = '/Volumes/work/workspace/Visualization_analysis_jpg_hevc/visualize/visualize_inception_featureMaps/all_new_graphs/PSNR_point_of_view'
  35. # path_hevc = '/Volumes/work/workspace/Visualization_analysis_jpg_hevc/visualize/visualize_inception_featureMaps/all_new_graphs/PSNR_point_of_view'
  36.  
  37. path_jpeg = './all_new_graphs/PSNR_Point_View/'
  38. path_hevc = './all_new_graphs/PSNR_Point_View/'
  39.  
  40.  
  41. import numpy as np
  42. import os
  43. import sys
  44.  
  45. flag = 0
  46. ## get idx and bin num from input
  47.  
  48. #img_idx =  int(sys.argv[1])
  49. img_idx = imgID = int(sys.argv[1])
  50.  
  51. if flag :
  52.     bin_num = 75
  53.     jpg_qf_idx =  int(sys.argv[2])
  54.     hevc_qp_idx = int(sys.argv[3])
  55.     from openpyxl import load_workbook
  56.     path_to_file =  '/home/h2amer/work/workspace/Visualization_analysis_jpg_hevc/visualize/visualize_inception_featureMaps/IV3-Qp-All_1_50000_HEVC.xlsx'
  57.     wb = load_workbook(filename=path_to_file, read_only=True, data_only=True)
  58.     ws = wb['Sheet1']
  59.  
  60.     sheet = wb.get_sheet_by_name('Sheet1')
  61.     N = 50000
  62.     max_row_limit = N
  63.     rank_hevc = np.zeros((N , 27))
  64.     rank_jpg = np.zeros((N,21))
  65.        
  66.     # np.save('rank_jpg', rank_jpg )
  67.     RANK_HEVC = np.load('rank_hevc.npy')
  68.     RANK_JPG = np.load('rank_jpg.npy')
  69.     print ('img id is' , img_idx)
  70.     print('hevc rank ' , RANK_HEVC[ img_idx - 1 , hevc_qp_idx] )
  71.     print('jpg rank ' , RANK_JPG[ img_idx - 1 , jpg_qf_idx] )
  72.     print('jpg qf_idx is' , jpg_qf_idx )
  73.  
  74. else:
  75.  
  76.     bin_num =  int(sys.argv[2])
  77.  
  78.     QP_idx_jpg = np.load(os.path.join( path_jpeg, 'QP_idx_jpg.npy'))
  79.     QP_idx_hevc = np.load(os.path.join( path_hevc , 'QP_idx_hevc.npy'))
  80.     img_qfs_hevc  = QP_idx_hevc[: , bin_num]
  81.     img_qfs_jpg  = QP_idx_jpg[: , bin_num]
  82.  
  83.     N = 50000
  84.     max_row_limit = N
  85.     rank_hevc = np.zeros((N , 27))
  86.     rank_jpg = np.zeros((N,21))
  87.        
  88.     # np.save('rank_jpg', rank_jpg )
  89.     RANK_HEVC = np.load('rank_hevc.npy')
  90.     RANK_JPG = np.load('rank_jpg.npy')
  91.     print ('img id is' , img_idx)
  92.     hevc_qp_idx = int( img_qfs_hevc[img_idx -1] )
  93.     jpg_qf_idx = int( img_qfs_jpg[img_idx -1])
  94.     imgID = img_idx
  95.     print('hevc rank ' , RANK_HEVC[ img_idx - 1 , int(img_qfs_hevc[img_idx -1]) ])
  96.     print('jpg rank ' , RANK_JPG[ img_idx - 1 , int(img_qfs_jpg[img_idx -1] )] )
  97.     print('jpg qf_idx is' , img_qfs_jpg[img_idx -1] )
  98.     #################################################################################3
  99.  
  100.  
  101.  
  102.  
  103. # needs more work
  104. #MODEL_PATH = '/Users/hossam.amer/7aS7aS_Works/work/jpeg_ml_research/inceptionv3/inception_model'
  105. MODEL_PATH = './inception_model'
  106.  
  107. # # YUV Path
  108. # PATH_TO_RECONS = '/Volumes/MULTICOMHD2/set_yuv/Seq-RECONS/'
  109.  
  110. # # JPEG Path
  111. # path_to_valid_images    = '/Volumes/MULTICOMHD2/validation_original/';
  112. # path_to_valid_QF_images = '/Volumes/MULTICOMHD2/validation_generated_QF/';
  113.  
  114. MAIN_PATH    = '/Volumes/MULTICOM102/103_HA/MULTICOM103/set_yuv/'
  115. # YUV Path
  116. PATH_TO_RECONS =  os.path.join(MAIN_PATH, 'Seq-RECONS-ffmpeg/')
  117.  
  118. # '/Volumes/MULTICOMHD2/set_yuv/Seq-RECONS/'
  119.  
  120. # JPEG Path
  121. path_to_valid_images    = '/media/h2amer/ADATA HD710/validation_generated_QF_0_5_100/'
  122.  
  123. #'/Volumes/MULTICOMHD2/validation_original/';
  124.  
  125.  
  126. #path_to_valid_QF_images = '/media/h2amer/ADATA HD710/validation_generated_QF_0_5_100/'
  127. # path_to_valid_QF_images    = '/media/h2amer/MULTICOM101/jpeg_data/validation_generated_QF/'
  128. #'/Volumes/MULTICOMHD2/validation_generated_QF/';
  129.  
  130. # path_to_valid_QF_images    = '/Volumes/MULTICOM101/jpeg_data/validation_generated_QF/'
  131. path_to_valid_QF_images    = '/Volumes/MULTICOM-104/validation_generated_QF/'
  132.  
  133.  
  134. # Main
  135.  
  136. # Print ops:
  137. # print_ops(sess)
  138. path = '/home/h2amer/work/workspace/Visualization_analysis_jpg_hevc/visualize/visualize_inception_featureMaps/analysis2/'
  139.  
  140. layerID = int(sys.argv[3])
  141. featureMapIdx =  -73
  142. isGrayScaleNorm = True
  143.  
  144.  
  145.  
  146. #读取训练好的Inception-v3模型来创建graph
  147. def create_graph():
  148.   # the class that's been created from the textual definition in graph.proto
  149.   #with tf.gfile.FastGFile('./inception_model/inception_v3_2016_08_28_frozen.pb', 'rb') as f:  
  150.     with tf.gfile.FastGFile(MODEL_PATH + '/classify_image_graph_def.pb', 'rb') as f:  
  151.         graph_def = tf.GraphDef()
  152.         graph_def.ParseFromString(f.read())
  153.         tf.import_graph_def(graph_def, name='')
  154.  
  155.  
  156. def print_ops(sess):
  157.   constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
  158.   for constant_op in constant_ops:
  159.     print(constant_op.name)  
  160.  
  161. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Hide the warning information from Tensorflow - annoying...
  162.  
  163.  
  164.  
  165.  
  166.  
  167.  
  168.  
  169. def show_image(imgID, QF, image_data, isCast = True):
  170.     # Parse the YUV and convert it into RGB
  171.     # original_img_ID = imgID
  172.     # imgID = str(imgID).zfill(8)
  173.     # shard_num  = round(original_img_ID/10000);
  174.     # folder_num = math.ceil(original_img_ID/1000);
  175.     original_img_ID = imgID
  176.     print('SHOW IMAGE: ', isCast)
  177.     imgID = str(imgID).zfill(8)
  178.     shard_num  = math.floor((original_img_ID - 1) / 10000)
  179.     folder_num = math.ceil(original_img_ID/1000)+1;
  180.     if (((original_img_ID-1)/1000.0)==folder_num-1):
  181.         folder_num = (original_img_ID-1)/1000
  182.    
  183.     if (folder_num == original_img_ID/1000 ):
  184.         folder_num = folder_num + 1
  185.     if ((folder_num-1)*1000==original_img_ID):
  186.         folder_num = folder_num - 1
  187.  
  188.     if not isCast:
  189.         if QF == 110:
  190.             image = path_to_valid_images + str(folder_num) + '/ILSVRC2012_val_' + imgID + '.JPEG'
  191.             figure_title = 'ILSVRC2012_val_' + imgID + '.JPEG'
  192.         else:
  193.             # shard_num = math.floor((original_img_ID - 1) / 10000)
  194.             # folder_num = math.ceil(original_img_ID/1000)
  195.             shard_num = math.floor((original_img_ID - 1) / 10000)
  196.             folder_num = math.ceil(original_img_ID/1000)+1;
  197.             if (((original_img_ID-1)/1000.0)==folder_num-1):
  198.                 folder_num = (original_img_ID-1)/1000
  199.                 print('here1')
  200.             if (folder_num == original_img_ID/1000 ):
  201.                 folder_num = folder_num + 1
  202.             if ((folder_num-1)*1000==original_img_ID):
  203.                 folder_num = folder_num - 1
  204.             #image = path_to_valid_QF_images + str(folder_num) + '/ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  205.             image = path_to_valid_QF_images + 'shard-' + str(int(shard_num)) + '/' + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  206.             figure_title = 'ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  207.  
  208.             print('Show image JPEG: ', image)
  209.             image_data = cv2.imread(image)
  210.  
  211.  
  212.         print('Save JPEG')
  213.         filename = './'+imgID+'/ILSVRC2012_val_' + imgID + '-QF-' + str(QF) +'.bmp'
  214.         savefile_ex(filename)
  215.         cv2.imwrite(filename, image_data)
  216.         # cv2.imwrite('/Users/ahamsala/Documents/7.visualize_code/Visualization_analysis_jpg_hevc/image1.bmp', image_data)
  217.         print(image)
  218.     else:
  219.         path_to_recons = PATH_TO_RECONS
  220.         # Get files list to fetch the correct name
  221.         print('PATH TO GLOB: ', path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '*.yuv')
  222.         filesList = glob.glob(path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '*.yuv')
  223.         name = filesList[0].split('/')[-1]
  224.         rgbStr = name.split('_')[5]
  225.         width  = int(name.split('_')[-4])
  226.         height = int(name.split('_')[-3])
  227.         is_gray_str = name.split('_')[-2]
  228.         figure_title = 'ILSVRC2012_val_' + imgID + '_' + str(width) + '_' + str(height) + '_' + rgbStr + '_' + str(QF) + '_1.yuv'
  229.         image = path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '_' + str(width) + '_' + str(height) + '_' + rgbStr + '_' + str(QF) + '.yuv'
  230.         print('Save: ', image)
  231.         size = (height, width)
  232.         videoObj = VideoCaptureYUV(image, size, isGrayScale=is_gray_str.__contains__('Y'))
  233.         ret, yuv, rgb = videoObj.getYUVAndRGB()
  234.         image_data = rgb
  235.  
  236.         print('Save HEVC')
  237.         filename = './'+imgID+'/ILSVRC2012_val_' + imgID + '_' + str(width) + '_' + str(height) + '_' + rgbStr + '_' + str(QF) +'.bmp'
  238.         savefile_ex(filename)
  239.         cv2.imwrite(filename, image_data)
  240.         print(image)
  241.  
  242.     # plt.figure(100*QF)
  243.     # plt.imshow(image_data)
  244.     # plt.suptitle(figure_title, fontsize=16)
  245.  
  246.  
  247.  
  248.  
  249. def get_image_data(imgID, QF, isCast = True):
  250.     # Parse the YUV and convert it into RGB
  251.     original_img_ID = imgID
  252.     imgID = str(imgID).zfill(8)
  253.     shard_num  = math.floor((original_img_ID - 1) / 10000)
  254.     folder_num = math.ceil(original_img_ID/1000)+1;
  255.     if (((original_img_ID-1)/1000.0)==folder_num-1):
  256.         folder_num = (original_img_ID-1)/1000
  257.         print('here1')
  258.     if (folder_num == original_img_ID/1000 ):
  259.         folder_num = folder_num + 1
  260.     if ((folder_num-1)*1000==original_img_ID):
  261.         folder_num = folder_num - 1
  262.  
  263.  
  264.     # shard_num  = round(original_img_ID/10000);
  265.     # folder_num = math.ceil(original_img_ID/1000)+1;
  266.     if isCast:
  267.         path_to_recons = PATH_TO_RECONS
  268.         # Get files list to fetch the correct name
  269.         filesList = glob.glob(path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '*.yuv')
  270.         print(path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '*.yuv')
  271.         name = filesList[0].split('/')[-1]
  272.         rgbStr = name.split('_')[5]
  273.         width  = int(name.split('_')[-4])
  274.         height = int(name.split('_')[-3])
  275.         is_gray_str = name.split('_')[-2]
  276.        
  277.         image = path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '_' + str(width) + '_' + str(height) + '_' + rgbStr + '_' + str(QF) + '.yuv'
  278.         figure_title = 'ILSVRC2012_val_' + imgID + '_' + str(width) + '_' + str(height) + '_' + rgbStr + '_' + str(QF) + '_1.yuv'
  279.         print(image)
  280.         size = (height, width) # height and then width
  281.         videoObj = VideoCaptureYUV(image, size, isGrayScale=is_gray_str.__contains__('Y'))
  282.         ret, yuv, rgb = videoObj.getYUVAndRGB()
  283.         image_data = rgb
  284.  
  285.     else:
  286.         if QF == 110:
  287.             image = path_to_valid_images + str(folder_num) + '/ILSVRC2012_val_' + imgID + '.JPEG'
  288.             figure_title = 'ILSVRC2012_val_' + imgID + '.JPEG'
  289.         else:
  290.             # shard_num = math.floor((original_img_ID - 1) / 10000)
  291.             # folder_num = math.ceil(original_img_ID/1000)
  292.             shard_num = math.floor((original_img_ID - 1) / 10000)
  293.             folder_num = math.ceil(original_img_ID/1000)+1;
  294.             if (((original_img_ID-1)/1000.0)==folder_num-1):
  295.                 folder_num = (original_img_ID-1)/1000
  296.                 print('here1')
  297.             if (folder_num == original_img_ID/1000 ):
  298.                 folder_num = folder_num + 1
  299.             if ((folder_num-1)*1000==original_img_ID):
  300.                 folder_num = folder_num - 1
  301.             #image = path_to_valid_QF_images + str(folder_num) + '/ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  302.             image = path_to_valid_QF_images + 'shard-' + str(int(shard_num)) + '/' + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  303.             figure_title = 'ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  304.         print(image)
  305.         image_data = tf.gfile.FastGFile(image, 'rb').read()
  306.     return image_data, figure_title
  307.  
  308.  
  309.  
  310. def plot(feature_maps, featureMapIdx, figure_title, isCast, imgID, bin_num):
  311.     K     = feature_maps.shape[2]
  312.     nRows = K//8
  313.     nCols = K//nRows
  314.     if featureMapIdx < 0:
  315.         fig, ax = plt.subplots(nrows=nRows, ncols=nCols, figsize=(10, 5))
  316.         plt.figure(figureID)
  317.         for irow, row in enumerate(ax):
  318.             for icol, col in enumerate(row):
  319.                 idx = irow + icol * nRows
  320.                 if idx >= K:
  321.                     continue
  322.  
  323.                 m = feature_maps[:, :, idx]
  324.                 if isGrayScaleNorm:
  325.                     A   = np.double(m)
  326.                     out = np.zeros(A.shape, np.double)
  327.                     m   = cv2.normalize(A, out, 255.0, 0.0, cv2.NORM_MINMAX)
  328.                     col.imshow(m, cmap='gray', vmin=0.0, vmax=255.0)
  329.                 else:
  330.                     col.imshow(m)
  331.                 col.axis('off')
  332.     else:
  333.         plt.figure(figureID)
  334.         m = feature_maps[:, :, featureMapIdx]
  335.         if isGrayScaleNorm:
  336.             A   = np.double(m)
  337.             out = np.zeros(A.shape, np.double)
  338.             m   = cv2.normalize(A, out, 255.0, 0.0, cv2.NORM_MINMAX)
  339.             plt.imshow(m, cmap='gray', vmin=0.0, vmax=255.0)
  340.         else:
  341.             plt.imshow(m)
  342.             plt.axis('off')
  343.         figure_title = str(featureMapIdx) + ')' + figure_title
  344.     plt.suptitle(figure_title, fontsize=16)
  345.     plt.subplots_adjust(wspace=0.0, hspace=0.0)
  346.    
  347.     if isCast:
  348.         filename = './'+str(imgID).zfill(8)+'/layerID_'+str(layerID)+'_'+str(bin_num)+'_'+str(imgID)+'/'+str(imgID) + '_' + str(bin_num) + '_hevc.png'
  349.         savefile_ex(filename)
  350.         plt.savefig(filename, dpi=600)
  351.         # plt.savefig(str(imgID) + '_' + str(bin_num) + '_hevc.png', dpi=600)
  352.     else:
  353.         filename = './'+str(imgID).zfill(8)+'/layerID_'+str(layerID)+'_'+str(bin_num)+'_'+str(imgID)+'/'+str(imgID) + '_' + str(bin_num) + '_jpg.png'
  354.         savefile_ex(filename)
  355.         plt.savefig(filename, dpi=600)
  356.         # plt.savefig(str(imgID) + '_' +str(bin_num) + '_jpg.png', dpi=600)
  357.  
  358. def savefile_ex(filename):
  359.     if not os.path.exists(os.path.dirname(filename)):
  360.         try:
  361.             os.makedirs(os.path.dirname(filename))
  362.         except OSError as exc: # Guard against race condition
  363.             if exc.errno != errno.EEXIST:
  364.                 raise
  365.  
  366. # Visualizes feature map of a specific image in the validation set
  367. def visualize_image(imgID, QF, layerID = 1, figureID = 1, isCast = True, isGrayScaleNorm = False, featureMapIdx = -1, bin_num = 0):
  368.  
  369.  
  370.     create_graph()
  371.     config = tf.ConfigProto(device_count = {'GPU': 0})
  372.     #sess = tf.Session()
  373.     sess = tf.Session(config=config)
  374.  
  375.     # Inception-v3: last layer is output as softmax
  376.     conv1_tensor = sess.graph.get_tensor_by_name('conv_' + str(layerID) + ':0')
  377.  
  378.     # Inception-v3: get the softmax tensor
  379.     softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
  380.  
  381.  
  382.  
  383.  
  384.     # Title of the figure
  385.     figure_title = ''
  386.  
  387.     # Get image data
  388.  
  389.     image_data, figure_title = get_image_data(imgID, QF, isCast)
  390.  
  391.     #Show the image
  392.     show_image(imgID, QF, image_data, isCast)
  393.  
  394.     if isCast:
  395.         feature_maps = sess.run(conv1_tensor, {'Cast:0': image_data}) # n, m, 3
  396.         predictions = sess.run(softmax_tensor,{'Cast:0': image_data}) # n, m, 3
  397.     else:
  398.         feature_maps = sess.run(conv1_tensor, {'DecodeJpeg/contents:0': image_data}) # n, m, 3
  399.         predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data}) # n, m, 3
  400.  
  401.     predictions = np.squeeze(predictions)
  402.  
  403.     # ID --> English string label.
  404.     node_lookup = NodeLookup()
  405.     N = -1008
  406.  
  407.     # Current_rank = -1
  408.     current_rank = -1
  409.  
  410.  
  411.     #(top-5)
  412.     top_5 = predictions.argsort()[N:][::-1]
  413.     for rank, node_id in enumerate(top_5):
  414.         human_string = node_lookup.id_to_string(node_id)
  415.         score = predictions[node_id]
  416.         # if rank < 5:
  417.         #     print('%d: %s (score = %.5f)' % (1 + rank, human_string, score))
  418.  
  419.         # if(gt_label_list[idx+1] == human_string):
  420.         #   print('%d: %s (score = %.5f)' % (1 + rank, human_string, score))
  421.     for idx1, rank_top5 in zip(range(1,6), top_5):
  422.         print('isHEVC: %d, Top-5 -- Node_ID: %d : %s (score = %.20f)' % (int(isCast), int(rank_top5), node_lookup.id_to_string(rank_top5), float(predictions[rank_top5])))
  423.  
  424.     # print(type(feature_maps))
  425.     #print(feature_maps.shape)
  426.  
  427.     feature_maps = np.reshape(feature_maps, [feature_maps.shape[1], feature_maps.shape[2], feature_maps.shape[3]])
  428.     plot(feature_maps, featureMapIdx, figure_title, isCast, imgID, bin_num)
  429.  
  430.     print('GrayScale: ', isGrayScaleNorm)
  431.     print ('Layer ID: %d' % layerID)
  432.     if featureMapIdx > 0:
  433.         print('Feature Map Index: %d' % featureMapIdx)
  434.        
  435.     print('\n')
  436.  
  437. QP = []
  438. QP.append(51)
  439. for i in range(50, -2 , -2):
  440.     QP.append(i)
  441.  
  442. QF = QP[hevc_qp_idx]
  443. figureID = 1
  444. # print('here',imgID)
  445. visualize_image(imgID, QF, layerID, figureID, True, isGrayScaleNorm ,featureMapIdx, bin_num )
  446.  
  447. QF = [i for i in range(0,100,5)]
  448. QF = QF[jpg_qf_idx]
  449. figureID = figureID + 1
  450. visualize_image(imgID, QF, layerID, figureID, False, isGrayScaleNorm , featureMapIdx, bin_num)
  451.  
  452. # QF = 110
  453. # figureID = figureID + 1
  454. # visualize_image(imgID, QF, layerID, figureID, False, isGrayScaleNorm, featureMapIdx)
  455.  
  456.  
  457. # QF = 10
  458. # figureID = figureID + 1
  459. # visualize_image(imgID, QF, layerID, figureID, False, isGrayScaleNorm)
  460.  
  461. # figureID = figureID + 1
  462. # visualize_image(imgID, QF, layerID, figureID, False, isGrayScaleNorm, featureMapIdx)
  463.  
  464.  
  465. # Show plt at the end
  466. plt.show()
RAW Paste Data