Advertisement
Guest User

Untitled

a guest
Aug 20th, 2019
92
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.42 KB | None | 0 0
  1. def _parse_function_3(frames):
  2. '''
  3. 分离卷积训练集parsing function
  4. 输入:tfr解码图文件(连续7帧)
  5. 输出:连续7帧经过随机crop和增广后的图片
  6. '''
  7.  
  8. im_wid = 192
  9. im_hig = 192
  10. cropped_frames = tf.random_crop(frames, (7, im_hig, im_wid, 3))
  11.  
  12. '''
  13. 7张图
  14. '''
  15. cropped0_img = cropped_frames[0, ...]
  16. cropped1_img = cropped_frames[1, ...]
  17. cropped2_img = cropped_frames[2, ...]
  18. cropped3_img = cropped_frames[3, ...]
  19. # cropped4_img = cropped_frames[4, ...]
  20. # cropped5_img = cropped_frames[5, ...]
  21. # cropped6_img = cropped_frames[6, ...]
  22.  
  23.  
  24. #### 随机四方向旋转
  25. rotat = tf.random_uniform([], 0, 4, tf.int32)# 随机旋转
  26. cropped0_img = tf.image.rot90(cropped0_img, k=rotat)
  27. cropped1_img = tf.image.rot90(cropped1_img, k=rotat)
  28. cropped2_img = tf.image.rot90(cropped2_img, k=rotat)
  29. cropped3_img = tf.image.rot90(cropped3_img, k=rotat)
  30. # cropped4_img = tf.image.rot90(cropped4_img, k=rotat)
  31. # cropped5_img = tf.image.rot90(cropped5_img, k=rotat)
  32. # cropped6_img = tf.image.rot90(cropped6_img, k=rotat)
  33.  
  34.  
  35. #### 随机左右flip
  36. do_flip = tf.random_uniform([]) > 0.5
  37. cropped0_img = tf.cond(do_flip, lambda: tf.image.flip_left_right(cropped0_img), lambda: cropped0_img)
  38. cropped1_img = tf.cond(do_flip, lambda: tf.image.flip_left_right(cropped1_img), lambda: cropped1_img)
  39. cropped2_img = tf.cond(do_flip, lambda: tf.image.flip_left_right(cropped2_img), lambda: cropped2_img)
  40. cropped3_img = tf.cond(do_flip, lambda: tf.image.flip_left_right(cropped3_img), lambda: cropped3_img)
  41. # cropped4_img = tf.cond(do_flip, lambda: tf.image.flip_left_right(cropped4_img), lambda: cropped4_img)
  42. # cropped5_img = tf.cond(do_flip, lambda: tf.image.flip_left_right(cropped5_img), lambda: cropped5_img)
  43. # cropped6_img = tf.cond(do_flip, lambda: tf.image.flip_left_right(cropped6_img), lambda: cropped6_img)
  44.  
  45.  
  46. #### 正反向读取 实为concat顺序决定输出顺序,此处的I帧暂时没有
  47. is_forward = tf.random_uniform([]) > 0.5
  48. out_image = tf.cond(is_forward, lambda: tf.concat([cropped0_img, cropped1_img, cropped2_img, cropped3_img], 1),
  49. lambda: tf.concat([cropped3_img, cropped2_img, cropped1_img, cropped0_img], 1))
  50. out_image = tf.cast(out_image, tf.float32) / 255.0
  51. return out_image
  52.  
  53. def _parse_valid(filename):
  54. image_string = tf.read_file(filename)
  55. image_decoded = tf.image.decode_png(image_string, channels=3)
  56. cropped1I_img = image_decoded[:, :256, :]
  57. cropped2I_img = image_decoded[:, 256:256*2, :]
  58. cropped3_img = image_decoded[:, 256*5:, :]
  59. out_image = tf.concat([cropped1I_img, cropped2I_img, cropped3_img], 1)
  60. out_image = tf.cast(out_image, tf.float32) / 255.0
  61. return out_image
  62.  
  63. def decode_tfr(serialized_example):
  64. # Prepare feature list; read encoded frames as bytes
  65. features = {
  66. 'shape': tf.FixedLenFeature([], tf.string),
  67. 'filename': tf.FixedLenFeature([], tf.string),
  68. 'frames': tf.FixedLenFeature([], tf.string),
  69. }
  70.  
  71. # Parse into tensors
  72. parsed_features = tf.parse_single_example(serialized_example, features)
  73.  
  74. # Decode the encoded frames
  75. frames = tf.decode_raw(parsed_features['frames'], tf.uint8)
  76. shape = tf.decode_raw(parsed_features['shape'], tf.int32)
  77.  
  78. # the frames tensor is flattened out, so we have to reconstruct the shape
  79. frames = tf.reshape(frames, shape)
  80.  
  81. return frames
  82.  
def build_model():
    """Load two checkpoints and set up the training input pipeline.

    As shown, this function:
      * calls ``load_post_ckpt()`` and ``load_joint_ckpt()`` and unpacks
        their 5-tuples;
      * collects every ``*.tfrecord`` file under ``train_set_dir``
        (recursively);
      * shuffles that file list once;
      * builds a repeating, batched, prefetched ``tf.data`` pipeline
        (``decode_tfr`` -> ``_parse_function_3``);
      * obtains ``next_train`` from a one-shot iterator.

    NOTE(review): ``load_post_ckpt``, ``load_joint_ckpt``, ``train_set_dir``
    and ``total_batch_size`` are not defined here. They are presumably
    module-level names from the rest of the original file; confirm.
    NOTE(review): the paste ends after ``next_train`` is created. None of
    the loaded sessions/tensors or ``next_train`` are used or returned in
    the visible code, so the training part appears to be missing.
    """
    sess_post, post_ori_image, input_image_p, recon_post, psnr_post = load_post_ckpt()
    sess_joint, joint_ori_image, joint_is_train, recon_joint, psnr_joint = load_joint_ckpt()

    # Gather all TFRecord files in train_set_dir and its subdirectories.
    train_filename_list = []
    for x in os.walk(train_set_dir):
        train_filename_list += [y for y in glob.glob(os.path.join(x[0], '*.tfrecord'))]
    # NOTE(review): `assert` is stripped under `python -O`; an explicit
    # raise would keep this "no training files found" check active.
    assert(train_filename_list)
    print(len(train_filename_list))
    # Shuffling filename_list directly (instead of in the dataset) seems to
    # be a bit faster.
    # NOTE(review): this shuffles the file order only once. The pipeline
    # below has no Dataset.shuffle(), so records inside a file are always
    # read in order and every repeat() epoch uses the same file order.
    random.shuffle(train_filename_list)

    # train_dataset = tf.data.Dataset.from_tensor_slices(train_filename_list)
    train_dataset = tf.data.TFRecordDataset(train_filename_list)
    # Record bytes -> uint8 frame tensor.
    train_dataset = train_dataset.map(decode_tfr, num_parallel_calls=12)
    # Frame tensor -> augmented float32 training sample.
    train_dataset = train_dataset.map(_parse_function_3, num_parallel_calls=12)
    train_dataset = train_dataset.repeat()
    train_dataset = train_dataset.batch(total_batch_size)
    train_dataset = train_dataset.prefetch(12)
    train_iterator = train_dataset.make_one_shot_iterator()
    next_train = train_iterator.get_next()
    # train
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement