Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def _parse_function_3(frames):
    """Training-set parsing function for the separable-convolution model.

    Args:
        frames: decoded frame stack for one example, shape (7, H, W, 3),
            dtype uint8 — assumes the layout produced by `decode_tfr`
            (TODO confirm against the writer side).

    Returns:
        float32 tensor in [0, 1]: the first FOUR cropped and augmented
        frames concatenated side by side along the width axis,
        shape (192, 192 * 4, 3). The remaining three frames are cropped
        but unused here.
    """
    crop_h = 192
    crop_w = 192
    # A single random crop over the whole stack keeps all frames
    # spatially aligned.
    cropped = tf.random_crop(frames, (7, crop_h, crop_w, 3))
    # Only the first 4 of the 7 frames are consumed by this variant.
    imgs = [cropped[i, ...] for i in range(4)]

    # Random rotation by a multiple of 90 degrees — same k for every
    # frame so the sequence stays consistent.
    rot_k = tf.random_uniform([], 0, 4, tf.int32)
    imgs = [tf.image.rot90(img, k=rot_k) for img in imgs]

    # Random horizontal flip, applied to all frames together.
    # (im=img default binds the loop variable — avoids the late-binding
    # lambda pitfall.)
    do_flip = tf.random_uniform([]) > 0.5
    imgs = [
        tf.cond(do_flip,
                lambda im=img: tf.image.flip_left_right(im),
                lambda im=img: im)
        for img in imgs
    ]

    # Randomly read the sequence forward or backward; the concat order
    # along the width axis determines the output order (the I-frame is
    # not present here yet).
    is_forward = tf.random_uniform([]) > 0.5
    out_image = tf.cond(is_forward,
                        lambda: tf.concat(imgs, 1),
                        lambda: tf.concat(imgs[::-1], 1))
    # Normalize uint8 pixels to [0, 1].
    out_image = tf.cast(out_image, tf.float32) / 255.0
    return out_image
def _parse_valid(filename):
    """Load one validation PNG and assemble the sub-images used for eval.

    The file holds 256-pixel-wide tiles laid out side by side; tiles 0
    and 1 plus everything from tile 5 onward are picked out and
    re-concatenated along the width axis, then scaled to [0, 1] float32.
    """
    tile_w = 256
    raw_bytes = tf.read_file(filename)
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    first_i = decoded[:, :tile_w, :]
    second_i = decoded[:, tile_w:2 * tile_w, :]
    tail = decoded[:, 5 * tile_w:, :]
    merged = tf.concat([first_i, second_i, tail], 1)
    return tf.cast(merged, tf.float32) / 255.0
def decode_tfr(serialized_example):
    """Deserialize one TFRecord example into its stacked-frames tensor.

    Each record stores the raw frame bytes together with the int32 shape
    needed to restore them; the 'filename' feature is parsed but not
    returned.
    """
    # Feature spec: every field was written as a bytes feature.
    feature_spec = {
        'shape': tf.FixedLenFeature([], tf.string),
        'filename': tf.FixedLenFeature([], tf.string),
        'frames': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, feature_spec)
    # Raw bytes -> flat uint8 tensor; the stored shape restores the
    # original layout.
    flat_frames = tf.decode_raw(parsed['frames'], tf.uint8)
    target_shape = tf.decode_raw(parsed['shape'], tf.int32)
    return tf.reshape(flat_frames, target_shape)
- def build_model():
- sess_post, post_ori_image, input_image_p, recon_post, psnr_post = load_post_ckpt()
- sess_joint, joint_ori_image, joint_is_train, recon_joint, psnr_joint = load_joint_ckpt()
- train_filename_list = []
- for x in os.walk(train_set_dir):
- train_filename_list += [y for y in glob.glob(os.path.join(x[0], '*.tfrecord'))]
- assert(train_filename_list)
- print(len(train_filename_list))
- # 好像直接 shuffle filename_list 更快一点
- random.shuffle(train_filename_list)
- # train_dataset = tf.data.Dataset.from_tensor_slices(train_filename_list)
- train_dataset = tf.data.TFRecordDataset(train_filename_list)
- train_dataset = train_dataset.map(decode_tfr, num_parallel_calls=12)
- train_dataset = train_dataset.map(_parse_function_3, num_parallel_calls=12)
- train_dataset = train_dataset.repeat()
- train_dataset = train_dataset.batch(total_batch_size)
- train_dataset = train_dataset.prefetch(12)
- train_iterator = train_dataset.make_one_shot_iterator()
- next_train = train_iterator.get_next()
- #train
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement