Guest User

Untitled

a guest
Jul 19th, 2018
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.54 KB | None | 0 0
  1. tf.reset_default_graph()
  2. with tf.Graph().as_default():
  3. # hyper-params
  4. learning_rate = 0.0002
  5. epochs = 250
  6. batch_size = 16
  7. N_w = 11 #number of frames concatenated together
  8. channels = 9*N_w
  9. drop_out = [0.5, 0.5, 0.5, 0, 0, 0, 0, 0]
  10.  
  11. def conv_down(x, N, stride, count): #Conv [4x4, str_2] > Batch_Normalization > Leaky_ReLU
  12. with tf.variable_scope("conv_down_{}_{}".format(N, count)) as scope: #N == depth of tensor
  13. with tf.variable_scope("conv_down_4x4_str{}".format(stride)) : #this's used for downsampling
  14. x = tf.layers.conv2d(x, N, kernel_size=4, strides=stride, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=np.sqrt(0.2)), name=scope)
  15. x = tf.contrib.layers.batch_norm(x)
  16. x = tf.nn.leaky_relu(x) #for conv_down, implement leakyReLU
  17. return x
  18.  
  19. def conv_up(x, N, drop_rate, stride, count): #Conv_transpose [4x4, str_2] > Batch_Normalizaiton > DropOut > ReLU
  20. with tf.variable_scope("{}".format(count)) as scope:
  21. x = tf.layers.conv2d_transpose(x, N, kernel_size=4, strides=stride, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=np.sqrt(0.2)), name=scope)
  22. x = tf.contrib.layers.batch_norm(x)
  23. if drop_rate is not 0:
  24. x = tf.nn.dropout(x, keep_prob=drop_rate)
  25. x = tf.nn.relu(x)
  26. return x
  27.  
  28. def conv_refine(x, N, drop_rate): #Conv [3x3, str_1] > Batch_Normalization > DropOut > ReLU
  29. x = tf.layers.conv2d(x, N, kernel_size=3, strides=1, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=np.sqrt(0.2)))
  30. x = tf.contrib.layers.batch_norm(x)
  31. if drop_rate is not 0:
  32. x = tf.nn.dropout(x, keep_prob=drop_rate)
  33. x = tf.nn.relu(x)
  34. return x
  35.  
  36. def conv_upsample(x, N, drop_rate, stride, count):
  37. with tf.variable_scope("conv_upsamp_{}_{}".format(N,count)) :
  38. with tf.variable_scope("conv_up_{}".format(count)):
  39. x = conv_up(x, 2*N, drop_rate, stride,count)
  40. with tf.variable_scope("refine1"):
  41. x = conv_refine(x, N, drop_rate)
  42. with tf.variable_scope("refine2"):
  43. x = conv_refine(x, N, drop_rate)
  44. return x
  45.  
  46. def biLinearDown(x, N):
  47. return tf.image.resize_images(x, [N, N])
  48.  
  49. def finalTanH(x):
  50. return tf.nn.tanh(x)
  51.  
  52. def T(x):
  53. #channel_output_structure
  54. down_channel_output = [64, 128, 256, 512, 512, 512, 512, 512]
  55. up_channel_output= [512, 512, 512, 512, 256, 128, 64, 3]
  56. biLinearDown_output= [32, 64, 128] #for skip-connection
  57.  
  58. #down_sampling
  59. conv1 = conv_down(x, down_channel_output[0], 2, 1)
  60. conv2 = conv_down(conv1, down_channel_output[1], 2, 2)
  61. conv3 = conv_down(conv2, down_channel_output[2], 2, 3)
  62. conv4 = conv_down(conv3, down_channel_output[3], 1, 4)
  63. conv5 = conv_down(conv4, down_channel_output[4], 1, 5)
  64. conv6 = conv_down(conv5, down_channel_output[5], 1, 6)
  65. conv7 = conv_down(conv6, down_channel_output[6], 1, 7)
  66. conv8 = conv_down(conv7, down_channel_output[7], 1, 8)
  67.  
  68. #upsampling
  69. dconv1 = conv_upsample(conv8, up_channel_output[0], drop_out[0], 1, 1)
  70. dconv2 = conv_upsample(dconv1, up_channel_output[1], drop_out[1], 1, 2)
  71. dconv3 = conv_upsample(dconv2, up_channel_output[2], drop_out[2], 1, 3)
  72. dconv4 = conv_upsample(dconv3, up_channel_output[3], drop_out[3], 1, 4)
  73. dconv5 = conv_upsample(dconv4, up_channel_output[4], drop_out[4], 1, 5)
  74. dconv6 = conv_upsample(tf.concat([dconv5, biLinearDown(x, biLinearDown_output[0])], axis=3), up_channel_output[5], drop_out[5], 2, 6)
  75. dconv7 = conv_upsample(tf.concat([dconv6, biLinearDown(x, biLinearDown_output[1])], axis=3), up_channel_output[6], drop_out[6], 2, 7)
  76. dconv8 = conv_upsample(tf.concat([dconv7, biLinearDown(x, biLinearDown_output[2])], axis=3), up_channel_output[7], drop_out[7], 2, 8)
  77.  
  78. #final_tanh
  79. T_x = finalTanH(dconv8)
  80.  
  81. return T_x
  82.  
  83. # input_tensor X
  84. x = tf.placeholder(tf.float32, [batch_size, 256, 256, channels]) # batch_size x Height x Width x N_w
  85.  
  86. # define sheudo_input for testing
  87. sheudo_input = np.float32(np.random.uniform(low=-1., high=1., size=[16, 256,256, 99]))
  88.  
  89. # initialize_
  90. init_g = tf.global_variables_initializer()
  91. init_l = tf.local_variables_initializer()
  92. with tf.Session() as sess:
  93. sess.run(init_g)
  94. sess.run(init_l)
  95. sess.run(T(x), feed_dict={x: sheudo_input})
Add Comment
Please sign in to add a comment.