Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# extract_features.py: extract Inception-v3 'pool_3' features and parent-directory class labels for a list of .mat images.
import os

import h5py
import numpy as np
import scipy.io as sio
import tensorflow as tf
from PIL import Image
# Input list of image paths (one per line), the frozen Inception-v3 graph,
# and the .mat files the features and labels are written to.
image_list = 'data/image_names.txt'
model_file = 'inception-2015-12-05/classify_image_graph_def.pb'
feat_file = 'data/inception_v3_features.mat'
labels_file = 'data/nith_light_labels.mat'

# Length of one squeezed 'pool_3' feature vector; also fixes the saved shape
# (0, FEAT_DIM) when the image list is empty.
FEAT_DIM = 2048

# Label for each three-letter prefix of an image's parent-directory name.
# Any other prefix maps to DEFAULT_LABEL.
LABEL_BY_PREFIX = {'hig': 2, 'med': 1}
DEFAULT_LABEL = 0

# Zero-based indices (last axis) of the channels taken from each .mat 'data'
# array, in the order they are fed to the network.
CHANNELS = [4, 3, 2]


def label_from_path(image_name):
    """Return the class label encoded in the parent directory of *image_name*.

    The first three characters of the parent directory name select the label
    via LABEL_BY_PREFIX ('hig' -> 2, 'med' -> 1); any other prefix yields
    DEFAULT_LABEL (0).

    Raises:
        ValueError: if *image_name* has no parent directory component.
    """
    parent = os.path.basename(os.path.dirname(image_name))
    if not parent:
        raise ValueError('cannot derive a label from %r: no parent directory'
                         % (image_name,))
    return LABEL_BY_PREFIX.get(parent[:3], DEFAULT_LABEL)


def read_image_names(path):
    """Return the stripped, non-blank lines of the image-list file at *path*."""
    with open(path, 'r') as list_file:
        # Blank lines (e.g. a stray empty line) would otherwise reach
        # label_from_path / loadmat and crash the whole run.
        return [line.strip() for line in list_file if line.strip()]


def load_image(image_name):
    """Load the 'data' array from the .mat file *image_name*, keeping CHANNELS."""
    # NOTE(review): the result is fed to 'DecodeJpeg:0', which presumably
    # expects an HxWx3 uint8 image like a decoded JPEG; confirm the stored
    # bands' dtype and value range match -- TODO confirm.
    return sio.loadmat(image_name)['data'][:, :, CHANNELS]


def stack_rows(rows, width):
    """Stack *rows* into a float64 array of shape (len(rows), width).

    Each row is a 1-D array of length *width* or, when width == 1, a scalar.
    An empty *rows* gives shape (0, width). Stacking once after the loop
    avoids the quadratic copying of growing an array with np.vstack on every
    iteration.

    Raises:
        ValueError: if a row's length differs from *width*.
    """
    # The float64 (0, width) seed pins the result's width and dtype exactly
    # as the previous incremental np.vstack against np.zeros did.
    return np.vstack([np.zeros((0, width))] + list(rows))


def load_graph(path):
    """Import the serialized GraphDef at *path* into the default TF graph."""
    graph_def = tf.GraphDef()
    # Context manager closes the model file once it has been read.
    with tf.gfile.FastGFile(path, 'rb') as model:
        graph_def.ParseFromString(model.read())
    tf.import_graph_def(graph_def, name='')


def main():
    """Extract a 'pool_3' feature and a label per listed image; save both."""
    load_graph(model_file)
    features = []
    labels = []
    # Session is closed even if an image fails to load or run.
    with tf.Session() as sess:
        feat_tensor = sess.graph.get_tensor_by_name('pool_3:0')
        print('inception v3 model file loaded\n')

        image_names = read_image_names(image_list)
        image_num = len(image_names)
        print("total image number: %d\n" % (image_num))

        for curnum, image_name in enumerate(image_names, 1):
            labels.append(label_from_path(image_name))
            print("processing %d/%d: %s\n" % (curnum, image_num, image_name))
            image_array = load_image(image_name)
            feat = sess.run(feat_tensor, {'DecodeJpeg:0': image_array})
            features.append(np.squeeze(feat))

    sio.savemat(feat_file, {'outfeats': stack_rows(features, FEAT_DIM)})
    sio.savemat(labels_file, {'labels': stack_rows(labels, 1)})


if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement