daily pastebin goal
18%
SHARE
TWEET

Untitled

a guest Mar 19th, 2018 72 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. for var in tf.trainable_variables():
  2.         if( 'resnet_v2_50/logits/weights' in var.name or 'resnet_v2_50/logits/biases' in var.name ):
  3.             variables.append(var)
  4.    
  5. grads = tf.gradients(loss, variables)
  6.    
  7. saver = tf.train.Saver(var_list=variables)
  8.      ...
  9.      if (global_steps_np + 1) % 1000 == 0:
  10.          saver.save(sess, save_path='model/dogClassification.ckpt', global_step=global_steps_np)
  11.    
  12. import os
  13. import numpy as np
  14. import cv2
  15. import pandas as pd
  16.  
  17. import tensorflow as tf
  18. import tensorflow.contrib as tc
  19. from tensorflow.contrib.slim.nets import resnet_v2
  20.  
  21. is_training = True
  22. slim = tc.slim
  23.  
  24. batch_size = 16
  25. class_types = 120
  26. img_w = resnet_v2.resnet_v2.default_image_size
  27. img_h = resnet_v2.resnet_v2.default_image_size
  28. img_c = 3
  29. global_steps = tf.train.get_or_create_global_step()
  30. decay_steps = 20000
  31. decay_rate = 0.1
  32. max_epoch = 100
  33.  
  34. def dog_generator(data_dir, data_labels_file, img_width, img_height, batch_size, max_epoch):
  35.     reading = True
  36.     file_list = []
  37.     data_labels = []
  38.     print("training image folder: ", data_dir)
  39.     print("training label file: ", data_labels_file)
  40.     labels = pd.read_csv(data_labels_file, header=None)
  41.  
  42.  
  43.     file_list = labels[0]
  44.     data_labels = labels[1]
  45.     data_id = set(data_labels)
  46.     label_from_str_to_id = dict()
  47.     label_from_id_to_str = dict()
  48.     c = 0
  49.     for i in data_id:
  50.         label_from_str_to_id[i] = c
  51.         label_from_id_to_str[c] = i
  52.         c = c + 1
  53.     start = 0
  54.     epoch = 0
  55.     while reading:
  56.         random_order = np.random.permutation(len(file_list))
  57.         x_batch = []
  58.         y_batch = []
  59.         for i in range(batch_size):
  60.             index = 0 if ((i + start) >= len(file_list)) else i + start
  61.             if (index == 0):
  62.                 epoch += 1
  63.             img = cv2.imread(os.path.join(data_dir, file_list[random_order[index]] + '.jpg'))
  64.             img = cv2.resize(img, (img_width, img_height))
  65.             x_batch.append(img)
  66.             y_batch.append(label_from_str_to_id[data_labels[random_order[index]]])
  67.             if (epoch == max_epoch):
  68.                 reading = False
  69.         start = index + 1
  70.         x_batch = np.asarray(x_batch)
  71.         y_batch = np.asarray(y_batch)
  72.         x_batch = x_batch / 128.0 - 1.0
  73.         yield x_batch, y_batch
  74.  
  75.  
  76. data = tf.placeholder(tf.float32, shape=[None, img_w, img_h, img_c])
  77. label = tf.placeholder(tf.int32, shape=[None])
  78.  
  79. if is_training == True:
  80.     g = dog_generator(os.path.join('data', 'train'),
  81.               os.path.join('data', 'labels.csv'), img_w, img_h,
  82.               batch_size, max_epoch)
  83. else:
  84.     g = dog_generator(os.path.join('data', 'train'),
  85.               os.path.join('data', 'labels.csv'), img_w, img_h,
  86.               batch_size, 100)
  87.  
  88.  
  89. def peek(iterable):
  90.     try:
  91.         first, second = next(iterable)
  92.     except StopIteration:
  93.         return None
  94.     return first, second, iterable
  95.  
  96. variables = [global_steps]
  97. variables1001 = []
  98.  
  99. with slim.arg_scope(resnet_v2.resnet_arg_scope()):
  100.     nets, scope = resnet_v2.resnet_v2_50(data, num_classes=class_types, is_training=is_training)
  101.     nets = tf.reshape(nets, [-1, class_types])
  102.     for var in tf.trainable_variables():
  103.         if( 'resnet_v2_50/logits/weights' in var.name or 'resnet_v2_50/logits/biases' in var.name ):
  104.             variables.append(var)
  105.         else:
  106.             variables1001.append(var)
  107.  
  108. with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
  109.  
  110.     loss = slim.losses.sparse_softmax_cross_entropy(nets, label)
  111.  
  112.     tf.summary.scalar("loss", loss)
  113.     merged_summary = tf.summary.merge_all()
  114.     pred_classes = tf.argmax(nets, axis=1)
  115.  
  116.     grads = tf.gradients(loss, variables)
  117.  
  118.     lr = tf.train.exponential_decay(learning_rate=0.001, global_step=global_steps, decay_steps=decay_steps,
  119.                                 decay_rate=decay_rate)
  120.     opt = tf.train.AdamOptimizer(lr)
  121.     train_op = opt.apply_gradients(zip(grads, variables), global_step=global_steps)
  122.  
  123.  
  124.  
  125. tf_writer = tf.summary.FileWriter(logdir='./')
  126. saver = tf.train.Saver(var_list=variables)
  127. saverImageNet = tf.train.Saver(var_list=variables1001)
  128.  
  129. with tf.Session() as sess:
  130.     writer = tf.summary.FileWriter('./', sess.graph)
  131.     sess.run(tf.global_variables_initializer())
  132.     saverImageNet.restore(sess, './resnet_v2_50_2017_04_14/resnet_v2_50.ckpt')
  133.     model_file=tf.train.latest_checkpoint('./model/')
  134.     print(model_file)
  135.     saver.restore(sess, tf.train.latest_checkpoint('./model/'))
  136.  
  137.     times = 0
  138.     accuracy = 0
  139.     while True:
  140.         res = peek(g)
  141.         if res == None:
  142.             print("End.")
  143.             break
  144.         else:
  145.             input_imgs, input_labels, g = res
  146.             input_labels = input_labels.astype(np.int32)
  147.  
  148.         if is_training == True:
  149.             _, loss_np, summary, global_steps_np = sess.run([train_op, loss, merged_summary, global_steps], feed_dict={data: input_imgs, label: input_labels})
  150.             writer.add_summary(summary, global_steps_np)
  151.  
  152.             if global_steps_np % 100 == 0:
  153.                 print("times: ", global_steps_np )
  154.                 print("loss: ", loss_np)
  155.  
  156.             if (global_steps_np + 1) % 1000 == 0:
  157.                 saver.save(sess, save_path='model/dogClassification.ckpt', global_step=global_steps_np)
  158.  
  159.         else:
  160.             test, pred_classes_np = sess.run([nets, pred_classes], feed_dict={data: input_imgs})
  161.             print(pred_classes_np)
  162.             print(input_labels)
  163.             times = times + 16
  164.             for i in range(len(pred_classes_np)):
  165.                 if( pred_classes_np[i] == input_labels[i]):
  166.                     accuracy = accuracy + 1
  167.             print(accuracy/times)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top