SHARE
TWEET

Untitled

a guest Oct 18th, 2019 103 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. # -*- coding: utf-8 -*-
  2. # Hossam Amer
  3. # Run using this way: python3 visualize_featureMaps.py
  4.  
  5. # Inception image recognition attempt v1
  6. import tensorflow as tf
  7. import numpy as np
  8. import re
  9. import os
  10. import time
  11. from tkinter import *
  12. import tkinter.filedialog
  13. import matplotlib.pyplot as plt
  14. import logging
  15.  
  16. # Video capture and convert rgb
  17. from video_capture import VideoCaptureYUV
  18. import cv2
  19.  
  20. # Node look up
  21. from node_lookup import NodeLookup
  22.  
  23. import time
  24.  
  25. # for fetching files
  26. import glob
  27.  
  28. import math
  29.  
  30. from random import randrange
  31. import errno
  32.  
  33. ## CODE FROM LAOD DATA
  34. # path_jpeg = '/Volumes/work/workspace/Visualization_analysis_jpg_hevc/visualize/visualize_inception_featureMaps/all_new_graphs/PSNR_point_of_view'
  35. # path_hevc = '/Volumes/work/workspace/Visualization_analysis_jpg_hevc/visualize/visualize_inception_featureMaps/all_new_graphs/PSNR_point_of_view'
  36.  
  37. path_jpeg = './all_new_graphs/PSNR_Point_View/'
  38. path_hevc = './all_new_graphs/PSNR_Point_View/'
  39.  
  40.  
  41. import numpy as np
  42. import os
  43. import sys
  44.  
# Selects how the quality indices are obtained:
#   truthy -> fixed bin (75); JPEG/HEVC indices come from argv[2] / argv[3].
#   falsy  -> bin number comes from argv[2]; per-image indices are looked up
#             in the QP_idx_*.npy tables for that bin.
flag = 0
## get idx and bin num from input

# argv[1] is the image ID; it is used as `img_idx - 1` to index table rows,
# i.e. it is treated as 1-based.
#img_idx =  int(sys.argv[1])
img_idx = imgID = int(sys.argv[1])

if flag :
    # Manual mode: quality indices given explicitly on the command line.
    bin_num = 75
    jpg_qf_idx =  int(sys.argv[2])
    hevc_qp_idx = int(sys.argv[3])
    from openpyxl import load_workbook
    path_to_file =  '/home/h2amer/work/workspace/Visualization_analysis_jpg_hevc/visualize/visualize_inception_featureMaps/IV3-Qp-All_1_50000_HEVC.xlsx'
    # NOTE(review): wb / ws / sheet are not read anywhere else in this file, and
    # get_sheet_by_name is a deprecated openpyxl API — confirm before enabling
    # this branch.
    wb = load_workbook(filename=path_to_file, read_only=True, data_only=True)
    ws = wb['Sheet1']

    sheet = wb.get_sheet_by_name('Sheet1')
    N = 50000
    max_row_limit = N
    # Zero-filled placeholders (27 HEVC / 21 JPEG columns); not read again in
    # this file — only the np.save below (commented out) referenced them.
    rank_hevc = np.zeros((N , 27))
    rank_jpg = np.zeros((N,21))

    # np.save('rank_jpg', rank_jpg )
    # Precomputed rank tables, indexed [img_idx - 1, quality_index].
    RANK_HEVC = np.load('rank_hevc.npy')
    RANK_JPG = np.load('rank_jpg.npy')
    print ('img id is' , img_idx)
    print('hevc rank ' , RANK_HEVC[ img_idx - 1 , hevc_qp_idx] )
    print('jpg rank ' , RANK_JPG[ img_idx - 1 , jpg_qf_idx] )
    print('jpg qf_idx is' , jpg_qf_idx )

else:
    # Table mode: argv[2] selects a column (bin) of the quality-index tables.

    bin_num =  int(sys.argv[2])

    QP_idx_jpg = np.load(os.path.join( path_jpeg, 'QP_idx_jpg.npy'))
    QP_idx_hevc = np.load(os.path.join( path_hevc , 'QP_idx_hevc.npy'))
    # Column `bin_num` of each table: one quality index per image row.
    img_qfs_hevc  = QP_idx_hevc[: , bin_num]
    img_qfs_jpg  = QP_idx_jpg[: , bin_num]

    N = 50000
    max_row_limit = N
    # Zero-filled placeholders (27 HEVC / 21 JPEG columns); not read again in
    # this file.
    rank_hevc = np.zeros((N , 27))
    rank_jpg = np.zeros((N,21))

    # np.save('rank_jpg', rank_jpg )
    # Precomputed rank tables, indexed [img_idx - 1, quality_index].
    RANK_HEVC = np.load('rank_hevc.npy')
    RANK_JPG = np.load('rank_jpg.npy')
    print ('img id is' , img_idx)
    # Quality indices for this image in the chosen bin; the main script later
    # maps them to an actual HEVC QP / JPEG QF.
    hevc_qp_idx = int( img_qfs_hevc[img_idx -1] )
    jpg_qf_idx = int( img_qfs_jpg[img_idx -1])
    imgID = img_idx
    print('hevc rank ' , RANK_HEVC[ img_idx - 1 , int(img_qfs_hevc[img_idx -1]) ])
    print('jpg rank ' , RANK_JPG[ img_idx - 1 , int(img_qfs_jpg[img_idx -1] )] )
    print('jpg qf_idx is' , img_qfs_jpg[img_idx -1] )
    #################################################################################3
  99.  
  100.  
  101.  
  102.  
  103. # needs more work
  104. #MODEL_PATH = '/Users/hossam.amer/7aS7aS_Works/work/jpeg_ml_research/inceptionv3/inception_model'
  105. MODEL_PATH = './inception_model'
  106.  
  107. # # YUV Path
  108. # PATH_TO_RECONS = '/Volumes/MULTICOMHD2/set_yuv/Seq-RECONS/'
  109.  
  110. # # JPEG Path
  111. # path_to_valid_images    = '/Volumes/MULTICOMHD2/validation_original/';
  112. # path_to_valid_QF_images = '/Volumes/MULTICOMHD2/validation_generated_QF/';
  113.  
  114. MAIN_PATH    = '/Volumes/MULTICOM102/103_HA/MULTICOM103/set_yuv/'
  115. # YUV Path
  116. PATH_TO_RECONS =  os.path.join(MAIN_PATH, 'Seq-RECONS-ffmpeg/')
  117.  
  118. # '/Volumes/MULTICOMHD2/set_yuv/Seq-RECONS/'
  119.  
  120. # JPEG Path
  121. path_to_valid_images    = '/media/h2amer/ADATA HD710/validation_generated_QF_0_5_100/'
  122.  
  123. #'/Volumes/MULTICOMHD2/validation_original/';
  124.  
  125.  
  126. #path_to_valid_QF_images = '/media/h2amer/ADATA HD710/validation_generated_QF_0_5_100/'
  127. # path_to_valid_QF_images    = '/media/h2amer/MULTICOM101/jpeg_data/validation_generated_QF/'
  128. #'/Volumes/MULTICOMHD2/validation_generated_QF/';
  129.  
  130. # path_to_valid_QF_images    = '/Volumes/MULTICOM101/jpeg_data/validation_generated_QF/'
  131. path_to_valid_QF_images    = '/Volumes/MULTICOM-104/validation_generated_QF/'
  132.  
  133.  
  134. # Main
  135.  
  136. # Print ops:
  137. # print_ops(sess)
  138. path = '/home/h2amer/work/workspace/Visualization_analysis_jpg_hevc/visualize/visualize_inception_featureMaps/analysis2/'
  139.  
  140. layerID = int(sys.argv[3])
  141. featureMapIdx =  -73
  142. isGrayScaleNorm = True
  143.  
  144.  
  145.  
  146. #读取训练好的Inception-v3模型来创建graph
  147. def create_graph():
  148.   # the class that's been created from the textual definition in graph.proto
  149.   #with tf.gfile.FastGFile('./inception_model/inception_v3_2016_08_28_frozen.pb', 'rb') as f:  
  150.     with tf.gfile.FastGFile(MODEL_PATH + '/classify_image_graph_def.pb', 'rb') as f:  
  151.         graph_def = tf.GraphDef()
  152.         graph_def.ParseFromString(f.read())
  153.         tf.import_graph_def(graph_def, name='')
  154.  
  155.  
  156. def print_ops(sess):
  157.   constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
  158.   for constant_op in constant_ops:
  159.     print(constant_op.name)  
  160.  
  161. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Hide the warning information from Tensorflow - annoying...
  162.  
  163.  
  164.  
  165.  
  166.  
  167.  
  168.  
  169. def show_image(imgID, QF, image_data, isCast = True):
  170.     # Parse the YUV and convert it into RGB
  171.     # original_img_ID = imgID
  172.     # imgID = str(imgID).zfill(8)
  173.     # shard_num  = round(original_img_ID/10000);
  174.     # folder_num = math.ceil(original_img_ID/1000);
  175.     original_img_ID = imgID
  176.     print('SHOW IMAGE: ', isCast)
  177.     imgID = str(imgID).zfill(8)
  178.     shard_num  = math.floor((original_img_ID - 1) / 10000)
  179.     folder_num = math.ceil(original_img_ID/1000)+1;
  180.     if (((original_img_ID-1)/1000.0)==folder_num-1):
  181.         folder_num = (original_img_ID-1)/1000
  182.    
  183.     if (folder_num == original_img_ID/1000 ):
  184.         folder_num = folder_num + 1
  185.     if ((folder_num-1)*1000==original_img_ID):
  186.         folder_num = folder_num - 1
  187.  
  188.     if not isCast:
  189.         if QF == 110:
  190.             image = path_to_valid_images + str(folder_num) + '/ILSVRC2012_val_' + imgID + '.JPEG'
  191.             figure_title = 'ILSVRC2012_val_' + imgID + '.JPEG'
  192.         else:
  193.             # shard_num = math.floor((original_img_ID - 1) / 10000)
  194.             # folder_num = math.ceil(original_img_ID/1000)
  195.             shard_num = math.floor((original_img_ID - 1) / 10000)
  196.             folder_num = math.ceil(original_img_ID/1000)+1;
  197.             if (((original_img_ID-1)/1000.0)==folder_num-1):
  198.                 folder_num = (original_img_ID-1)/1000
  199.                 print('here1')
  200.             if (folder_num == original_img_ID/1000 ):
  201.                 folder_num = folder_num + 1
  202.             if ((folder_num-1)*1000==original_img_ID):
  203.                 folder_num = folder_num - 1
  204.             #image = path_to_valid_QF_images + str(folder_num) + '/ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  205.             image = path_to_valid_QF_images + 'shard-' + str(int(shard_num)) + '/' + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  206.             figure_title = 'ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  207.  
  208.             print('Show image JPEG: ', image)
  209.             image_data = cv2.imread(image)
  210.  
  211.  
  212.         print('Save JPEG')
  213.         filename = './'+imgID+'/ILSVRC2012_val_' + imgID + '-QF-' + str(QF) +'.bmp'
  214.         savefile_ex(filename)
  215.         cv2.imwrite(filename, image_data)
  216.         # cv2.imwrite('/Users/ahamsala/Documents/7.visualize_code/Visualization_analysis_jpg_hevc/image1.bmp', image_data)
  217.         print(image)
  218.     else:
  219.         path_to_recons = PATH_TO_RECONS
  220.         # Get files list to fetch the correct name
  221.         print('PATH TO GLOB: ', path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '*.yuv')
  222.         filesList = glob.glob(path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '*.yuv')
  223.         name = filesList[0].split('/')[-1]
  224.         rgbStr = name.split('_')[5]
  225.         width  = int(name.split('_')[-4])
  226.         height = int(name.split('_')[-3])
  227.         is_gray_str = name.split('_')[-2]
  228.         figure_title = 'ILSVRC2012_val_' + imgID + '_' + str(width) + '_' + str(height) + '_' + rgbStr + '_' + str(QF) + '_1.yuv'
  229.         image = path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '_' + str(width) + '_' + str(height) + '_' + rgbStr + '_' + str(QF) + '.yuv'
  230.         print('Save: ', image)
  231.         size = (height, width)
  232.         videoObj = VideoCaptureYUV(image, size, isGrayScale=is_gray_str.__contains__('Y'))
  233.         ret, yuv, rgb = videoObj.getYUVAndRGB()
  234.         image_data = rgb
  235.  
  236.         print('Save HEVC')
  237.         filename = './'+imgID+'/ILSVRC2012_val_' + imgID + '_' + str(width) + '_' + str(height) + '_' + rgbStr + '_' + str(QF) +'.bmp'
  238.         savefile_ex(filename)
  239.         cv2.imwrite(filename, image_data)
  240.         print(image)
  241.  
  242.     # plt.figure(100*QF)
  243.     # plt.imshow(image_data)
  244.     # plt.suptitle(figure_title, fontsize=16)
  245.  
  246.  
  247.  
  248.  
  249. def get_image_data(imgID, QF, isCast = True):
  250.     # Parse the YUV and convert it into RGB
  251.     original_img_ID = imgID
  252.     imgID = str(imgID).zfill(8)
  253.     shard_num  = math.floor((original_img_ID - 1) / 10000)
  254.     folder_num = math.ceil(original_img_ID/1000)+1;
  255.     if (((original_img_ID-1)/1000.0)==folder_num-1):
  256.         folder_num = (original_img_ID-1)/1000
  257.         print('here1')
  258.     if (folder_num == original_img_ID/1000 ):
  259.         folder_num = folder_num + 1
  260.     if ((folder_num-1)*1000==original_img_ID):
  261.         folder_num = folder_num - 1
  262.  
  263.  
  264.     # shard_num  = round(original_img_ID/10000);
  265.     # folder_num = math.ceil(original_img_ID/1000)+1;
  266.     if isCast:
  267.         path_to_recons = PATH_TO_RECONS
  268.         # Get files list to fetch the correct name
  269.         filesList = glob.glob(path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '*.yuv')
  270.         print(path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '*.yuv')
  271.         name = filesList[0].split('/')[-1]
  272.         rgbStr = name.split('_')[5]
  273.         width  = int(name.split('_')[-4])
  274.         height = int(name.split('_')[-3])
  275.         is_gray_str = name.split('_')[-2]
  276.        
  277.         image = path_to_recons + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '_' + str(width) + '_' + str(height) + '_' + rgbStr + '_' + str(QF) + '.yuv'
  278.         figure_title = 'ILSVRC2012_val_' + imgID + '_' + str(width) + '_' + str(height) + '_' + rgbStr + '_' + str(QF) + '_1.yuv'
  279.         print(image)
  280.         size = (height, width) # height and then width
  281.         videoObj = VideoCaptureYUV(image, size, isGrayScale=is_gray_str.__contains__('Y'))
  282.         ret, yuv, rgb = videoObj.getYUVAndRGB()
  283.         image_data = rgb
  284.  
  285.     else:
  286.         if QF == 110:
  287.             image = path_to_valid_images + str(folder_num) + '/ILSVRC2012_val_' + imgID + '.JPEG'
  288.             figure_title = 'ILSVRC2012_val_' + imgID + '.JPEG'
  289.         else:
  290.             # shard_num = math.floor((original_img_ID - 1) / 10000)
  291.             # folder_num = math.ceil(original_img_ID/1000)
  292.             shard_num = math.floor((original_img_ID - 1) / 10000)
  293.             folder_num = math.ceil(original_img_ID/1000)+1;
  294.             if (((original_img_ID-1)/1000.0)==folder_num-1):
  295.                 folder_num = (original_img_ID-1)/1000
  296.                 print('here1')
  297.             if (folder_num == original_img_ID/1000 ):
  298.                 folder_num = folder_num + 1
  299.             if ((folder_num-1)*1000==original_img_ID):
  300.                 folder_num = folder_num - 1
  301.             #image = path_to_valid_QF_images + str(folder_num) + '/ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  302.             image = path_to_valid_QF_images + 'shard-' + str(int(shard_num)) + '/' + str(int(folder_num)) + '/ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  303.             figure_title = 'ILSVRC2012_val_' + imgID + '-QF-' + str(QF) + '.JPEG'
  304.         print(image)
  305.         image_data = tf.gfile.FastGFile(image, 'rb').read()
  306.     return image_data, figure_title
  307.  
  308.  
  309.  
  310. def plot(feature_maps, featureMapIdx, figure_title, isCast, imgID, bin_num):
  311.     K     = feature_maps.shape[2]
  312.     nRows = K//8
  313.     nCols = K//nRows
  314.     if featureMapIdx < 0:
  315.         fig, ax = plt.subplots(nrows=nRows, ncols=nCols, figsize=(10, 5))
  316.         plt.figure(figureID)
  317.         for irow, row in enumerate(ax):
  318.             for icol, col in enumerate(row):
  319.                 idx = irow + icol * nRows
  320.                 if idx >= K:
  321.                     continue
  322.  
  323.                 m = feature_maps[:, :, idx]
  324.                 if isGrayScaleNorm:
  325.                     A   = np.double(m)
  326.                     out = np.zeros(A.shape, np.double)
  327.                     m   = cv2.normalize(A, out, 255.0, 0.0, cv2.NORM_MINMAX)
  328.                     col.imshow(m, cmap='gray', vmin=0.0, vmax=255.0)
  329.                 else:
  330.                     col.imshow(m)
  331.                 col.axis('off')
  332.     else:
  333.         plt.figure(figureID)
  334.         m = feature_maps[:, :, featureMapIdx]
  335.         if isGrayScaleNorm:
  336.             A   = np.double(m)
  337.             out = np.zeros(A.shape, np.double)
  338.             m   = cv2.normalize(A, out, 255.0, 0.0, cv2.NORM_MINMAX)
  339.             plt.imshow(m, cmap='gray', vmin=0.0, vmax=255.0)
  340.         else:
  341.             plt.imshow(m)
  342.             plt.axis('off')
  343.         figure_title = str(featureMapIdx) + ')' + figure_title
  344.     plt.suptitle(figure_title, fontsize=16)
  345.     plt.subplots_adjust(wspace=0.0, hspace=0.0)
  346.    
  347.     if isCast:
  348.         filename = './'+str(imgID).zfill(8)+'/layerID_'+str(layerID)+'_'+str(bin_num)+'_'+str(imgID)+'/'+str(imgID) + '_' + str(bin_num) + '_hevc.png'
  349.         savefile_ex(filename)
  350.         plt.savefig(filename, dpi=600)
  351.         # plt.savefig(str(imgID) + '_' + str(bin_num) + '_hevc.png', dpi=600)
  352.     else:
  353.         filename = './'+str(imgID).zfill(8)+'/layerID_'+str(layerID)+'_'+str(bin_num)+'_'+str(imgID)+'/'+str(imgID) + '_' + str(bin_num) + '_jpg.png'
  354.         savefile_ex(filename)
  355.         plt.savefig(filename, dpi=600)
  356.         # plt.savefig(str(imgID) + '_' +str(bin_num) + '_jpg.png', dpi=600)
  357.  
  358. def savefile_ex(filename):
  359.     if not os.path.exists(os.path.dirname(filename)):
  360.         try:
  361.             os.makedirs(os.path.dirname(filename))
  362.         except OSError as exc: # Guard against race condition
  363.             if exc.errno != errno.EEXIST:
  364.                 raise
  365.  
  366. # Visualizes feature map of a specific image in the validation set
  367. def visualize_image(imgID, QF, layerID = 1, figureID = 1, isCast = True, isGrayScaleNorm = False, featureMapIdx = -1, bin_num = 0):
  368.  
  369.  
  370.     create_graph()
  371.     config = tf.ConfigProto(device_count = {'GPU': 0})
  372.     #sess = tf.Session()
  373.     sess = tf.Session(config=config)
  374.  
  375.     # Inception-v3: last layer is output as softmax
  376.     conv1_tensor = sess.graph.get_tensor_by_name('conv_' + str(layerID) + ':0')
  377.  
  378.     # Inception-v3: get the softmax tensor
  379.     softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
  380.  
  381.  
  382.  
  383.  
  384.     # Title of the figure
  385.     figure_title = ''
  386.  
  387.     # Get image data
  388.  
  389.     image_data, figure_title = get_image_data(imgID, QF, isCast)
  390.  
  391.     #Show the image
  392.     show_image(imgID, QF, image_data, isCast)
  393.  
  394.     if isCast:
  395.         feature_maps = sess.run(conv1_tensor, {'Cast:0': image_data}) # n, m, 3
  396.         predictions = sess.run(softmax_tensor,{'Cast:0': image_data}) # n, m, 3
  397.     else:
  398.         feature_maps = sess.run(conv1_tensor, {'DecodeJpeg/contents:0': image_data}) # n, m, 3
  399.         predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data}) # n, m, 3
  400.  
  401.     predictions = np.squeeze(predictions)
  402.  
  403.     # ID --> English string label.
  404.     node_lookup = NodeLookup()
  405.     N = -1008
  406.  
  407.     # Current_rank = -1
  408.     current_rank = -1
  409.  
  410.  
  411.     #(top-5)
  412.     top_5 = predictions.argsort()[N:][::-1]
  413.     for rank, node_id in enumerate(top_5):
  414.         human_string = node_lookup.id_to_string(node_id)
  415.         score = predictions[node_id]
  416.         # if rank < 5:
  417.         #     print('%d: %s (score = %.5f)' % (1 + rank, human_string, score))
  418.  
  419.         # if(gt_label_list[idx+1] == human_string):
  420.         #   print('%d: %s (score = %.5f)' % (1 + rank, human_string, score))
  421.     for idx1, rank_top5 in zip(range(1,6), top_5):
  422.         print('isHEVC: %d, Top-5 -- Node_ID: %d : %s (score = %.20f)' % (int(isCast), int(rank_top5), node_lookup.id_to_string(rank_top5), float(predictions[rank_top5])))
  423.  
  424.     # print(type(feature_maps))
  425.     #print(feature_maps.shape)
  426.  
  427.     feature_maps = np.reshape(feature_maps, [feature_maps.shape[1], feature_maps.shape[2], feature_maps.shape[3]])
  428.     plot(feature_maps, featureMapIdx, figure_title, isCast, imgID, bin_num)
  429.  
  430.     print('GrayScale: ', isGrayScaleNorm)
  431.     print ('Layer ID: %d' % layerID)
  432.     if featureMapIdx > 0:
  433.         print('Feature Map Index: %d' % featureMapIdx)
  434.        
  435.     print('\n')
  436.  
  437. QP = []
  438. QP.append(51)
  439. for i in range(50, -2 , -2):
  440.     QP.append(i)
  441.  
  442. QF = QP[hevc_qp_idx]
  443. figureID = 1
  444. # print('here',imgID)
  445. visualize_image(imgID, QF, layerID, figureID, True, isGrayScaleNorm ,featureMapIdx, bin_num )
  446.  
  447. QF = [i for i in range(0,100,5)]
  448. QF = QF[jpg_qf_idx]
  449. figureID = figureID + 1
  450. visualize_image(imgID, QF, layerID, figureID, False, isGrayScaleNorm , featureMapIdx, bin_num)
  451.  
  452. # QF = 110
  453. # figureID = figureID + 1
  454. # visualize_image(imgID, QF, layerID, figureID, False, isGrayScaleNorm, featureMapIdx)
  455.  
  456.  
  457. # QF = 10
  458. # figureID = figureID + 1
  459. # visualize_image(imgID, QF, layerID, figureID, False, isGrayScaleNorm)
  460.  
  461. # figureID = figureID + 1
  462. # visualize_image(imgID, QF, layerID, figureID, False, isGrayScaleNorm, featureMapIdx)
  463.  
  464.  
  465. # Show plt at the end
  466. plt.show()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top