Advertisement
Guest User

Untitled

a guest
May 24th, 2017
126
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.22 KB | None | 0 0
  1. import skimage.io # bug. need to import this before tensorflow
  2. import skimage.transform # bug. need to import this before tensorflow
  3. #from resnet_train import train
  4. import tensorflow as tf
  5. import time
  6. import os
  7. import sys
  8. import re
  9. import numpy as np
  10.  
  11. from synset import *
  12. from image_processing import image_preprocessing
  13.  
  14. import input_data
  15. from resnet34 import inference
  16. from resnet34 import inferencefinetune
  17. from resnet34 import fcf
  18. from resnet34 import *
  19. import tensorflow as tf
  20. import input_data
  21. from checkpoint import print_tensors_in_checkpoint_file
  22. MOMENTUM = 0.9
  23.  
  24. def checkpoint_fn(layers):
  25. return './pretrained/ResNet-L%d.ckpt' % layers
  26.  
  27. def top_k_error(predictions, labels, k):
  28. batch_size = float(FLAGS.batch_size) #tf.shape(predictions)[0]
  29. in_top1 = tf.to_float(tf.nn.in_top_k(predictions, labels, k))
  30. num_correct = tf.reduce_sum(in_top1)
  31. return (batch_size - num_correct) / batch_size
  32.  
  33. os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
  34. os.environ["CUDA_VISIBLE_DEVICES"]="5"
  35.  
  36.  
  37. FLAGS = tf.app.flags.FLAGS
  38.  
  39. tf.app.flags.DEFINE_string('train_dir', './tmp/resnet_train_0',
  40. """Directory where to write event logs """
  41. """and checkpoint.""")
  42. tf.app.flags.DEFINE_float('learning_rate', 0.01, "learning rate.")
  43. tf.app.flags.DEFINE_integer('batch_size', 16, "batch size")
  44. tf.app.flags.DEFINE_integer('max_steps', 500000, "max steps")
  45. tf.app.flags.DEFINE_boolean('resume', False,
  46. 'resume from latest saved state')
  47. tf.app.flags.DEFINE_boolean('minimal_summaries', True,
  48. 'produce fewer summaries to save HD space')
  49.  
  50.  
  51. # Path for tf.summary.FileWriter and to store model checkpoints
  52. filewriter_path = "./tmp/finetune_resnet0/log"
  53. checkpoint_path = "./tmp/finetune_resnet0/"
  54.  
  55. # Create parent path if it doesn't exist
  56. if not os.path.isdir(checkpoint_path): os.mkdir(checkpoint_path)
  57.  
  58.  
  59.  
  60. def main(_):
  61.  
  62. is_training = tf.placeholder('bool', [], name='is_training')
  63. # logits, pools = inferencefinetune(images,
  64. # num_classes=1000,
  65. # is_training=True,
  66. # bottleneck=False,
  67. # num_blocks=[3, 4, 6, 3])
  68.  
  69. xx=tf.placeholder(tf.float32,[16,224,224,3])
  70. #images=placeholder
  71. #labels=placehodler
  72.  
  73. logit = inference(xx,
  74. num_classes=41,
  75. is_training=True,
  76. bottleneck=False,
  77. num_blocks=[3, 4, 6, 3]) # num_blocks = [2,2,2,2]
  78.  
  79. global_step = tf.get_variable('global_step', [],
  80. initializer=tf.constant_initializer(0),
  81. trainable=False)
  82. val_step = tf.get_variable('val_step', [],
  83. initializer=tf.constant_initializer(0),
  84. trainable=False)
  85. yy=tf.placeholder(tf.int64, 16)
  86. loss1 = loss(logit, yy )
  87.  
  88.  
  89. predictions = tf.nn.softmax(logit)
  90.  
  91. top1_error = top_k_error(predictions, yy, 1)
  92.  
  93. # Evaluation op: Accuracy of the model
  94. #with tf.name_scope("accuracy"):
  95. # correct_pred = tf.equal(tf.argmax(predictions, 1), labels)
  96. # accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
  97.  
  98. # Add the accuracy to the summary
  99. #tf.summary.scalar('accuracy', accuracy)
  100.  
  101. # loss_avg
  102. ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
  103. tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss1]))
  104. tf.summary.scalar('loss_avg', ema.average(loss1))
  105.  
  106. # validation stats
  107. ema = tf.train.ExponentialMovingAverage(0.9, val_step)
  108. val_op_ = tf.group(val_step.assign_add(1), ema.apply([top1_error]))
  109. top1_error_avg = ema.average(top1_error)
  110. # tf.summary.scalar('val_top1_error_avg', top1_error_avg)
  111.  
  112. # tf.summary.scalar('learning_rate', FLAGS.learning_rate)
  113.  
  114. opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
  115. grads = opt.compute_gradients(loss1)
  116. for grad, var in grads:
  117. if grad is not None and not FLAGS.minimal_summaries:
  118. tf.summary.histogram(var.op.name + '/gradients', grad)
  119. apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
  120.  
  121. # if not FLAGS.minimal_summaries:
  122. # # Display the training images in the visualizer.
  123. # tf.summary.image('images', images)
  124.  
  125. # for var in tf.trainable_variables():
  126. # tf.summary.histogram(var.op.name, var)
  127.  
  128. batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
  129. batchnorm_updates_op = tf.group(*batchnorm_updates)
  130. train_op = tf.group(apply_gradient_op, batchnorm_updates_op)
  131.  
  132.  
  133. summary_op = tf.summary.merge_all()
  134.  
  135. init = tf.global_variables_initializer()
  136.  
  137. sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
  138. sess.run(init)
  139.  
  140. saver = tf.train.Saver(tf.global_variables())
  141.  
  142. k1=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='scale5')
  143. k2=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='scale4')
  144. k3=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='scale3')
  145. k4=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='scale2')
  146. k5=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='scale1')
  147. k6=k1+k2+k3+k4+k5
  148. #print(k6)
  149. saver1 = tf.train.Saver(k6)
  150. saver1.restore(sess, './tmp/res50imagenet/model.ckpt-3M')
  151. #saver.restore(sess, './tmp/resnet_train_0/model(test).ckpt-240301')
  152. print('model loaded')
  153. with open('list/train.list','r') as lines:
  154. fcvid_path = '/mnt/hdd/ockwon/tensorflow1/1FPS/fcvid41yt8m3/'
  155. lines1 = list(lines)
  156. with open('list/test.list','r') as linesk:
  157. fcvid_path = '/mnt/hdd/ockwon/tensorflow1/1FPS/fcvid41yt8m3/'
  158. lines2 = list(linesk)
  159. start_time = time.time()
  160. for x in xrange(FLAGS.max_steps+1):
  161. image1, label1, _, _, _ = input_data.read_frame_and_label(
  162. filename=lines1,
  163. batch_size=FLAGS.batch_size,
  164. num_frames_per_clip=1,
  165. crop_size=224,
  166. shuffle=False)
  167. tf.train.start_queue_runners(sess=sess)
  168.  
  169. # summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
  170.  
  171.  
  172.  
  173.  
  174.  
  175. step = sess.run(global_step)
  176. i = [train_op, loss1]
  177.  
  178. write_summary = step % 100 and step > 1
  179. # if write_summary:
  180. # i.append(summary_op)
  181. loss_, val_op = sess.run( [loss1, val_op_], feed_dict={
  182. xx : image1,
  183. yy : label1,
  184. is_training : True
  185. })
  186.  
  187. o = sess.run(i, feed_dict={ is_training: True,
  188. xx : image1,
  189. yy : label1 })
  190.  
  191. loss_value = o[1]
  192.  
  193. #duration = time.time() - start_timeqq
  194.  
  195. assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
  196.  
  197. if step % 10 == 0:
  198. #examples_per_sec = FLAGS.batch_size / float(duration)
  199. format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
  200. 'sec/batch)')
  201. print('step', step, 'loss', loss_value)
  202.  
  203. # if write_summary:
  204. # summary_str = o[2]
  205. # summary_writer.add_summary(summary_str, step)
  206.  
  207. # Save the model checkpoint periodically.
  208. if step > 1 and step % 100 == 0:
  209. checkpoint_path = os.path.join(FLAGS.train_dir, 'model(new).ckpt')
  210. saver.save(sess, checkpoint_path, global_step=global_step)
  211.  
  212. # Run validation periodically
  213. if step > 1 and step % 100 == 0:
  214. image2, label2, _, _, _ = input_data.read_frame_and_label(
  215. filename=lines2,
  216. batch_size=FLAGS.batch_size,
  217. num_frames_per_clip=1,
  218. crop_size=224,
  219. shuffle=False)
  220. #acc = sess.run([logits, accuracy], {is_training: False}, )
  221. #_, acc = sess.run([val_op, accuracy], {is_training: False})
  222. #print('Validation top1 accuracy %.2f' % acc)
  223.  
  224. _, top1_error_value = sess.run([val_op_, top1_error], feed_dict={ is_training: False,
  225. xx : image2,
  226. yy : label2 })
  227. print('Validation top1 error %.2f' % top1_error_value) #cf = {}mag
  228.  
  229.  
  230. if __name__ == '__main__':
  231. tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement