Advertisement
Guest User

Untitled

a guest
Apr 23rd, 2019
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 20.90 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3. import MSSSIM
  4.  
  5. # reset graph
  6. tf.reset_default_graph()
  7.  
  8.  
  9. # Input functions ------------------------------------------------------------------------------------------------------
  10.  
  11.  
  12. def _parse_function(example_proto):
  13.     keys_to_features = {'image/encoded': tf.VarLenFeature(tf.string)}
  14.     parsed_features = tf.parse_example(example_proto, keys_to_features)
  15.     raw = tf.sparse_tensor_to_dense(parsed_features['image/encoded'], default_value="0", )
  16.  
  17.     return (tf.map_fn(decode_random_crop, tf.squeeze(raw), dtype=tf.uint8, back_prop=False))
  18.  
  19.  
  20. def decode_random_crop(raw):
  21.     img = tf.image.decode_jpeg(raw, channels=3, try_recover_truncated=True, acceptable_fraction=0.5)
  22.  
  23.     return tf.cast(tf.squeeze(tf.image.crop_and_resize(tf.expand_dims(img, axis=0), [[0., 0., 1., 1.]], [0], [200, 200])), dtype=tf.uint8)
  24.  
  25.  
  26. def get_train_dataset():
  27.     files = tf.data.Dataset.list_files("train/train-*")
  28.     dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=1)
  29.     dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(16))
  30.     dataset = dataset.shuffle(3500).repeat()
  31.     dataset = dataset.map(_parse_function)
  32.     dataset = dataset.prefetch(30)
  33.  
  34.     return dataset
  35.  
  36.  
  37. def get_test_dataset():
  38.     files = tf.data.Dataset.list_files("/path/to/validation/validation-*")
  39.     dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=1)
  40.     dataset = dataset.batch(100).map(_parse_function)
  41.  
  42.     return dataset
  43.  
  44.  
  45. # MS-SSIM functions ------------------------------------------------------------------------------------------------------------------
  46.  
  47. # dataset and iterator initialization
  48.  
  49. training_dataset = get_train_dataset()
  50. test_dataset = get_test_dataset()
  51.  
  52. iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
  53.                                            training_dataset.output_shapes)
  54.  
  55. # variables initialization ------------------------------------------------------------------------------------------------------------
  56.  
  57. # network hyper-parameter
  58. batch_size = 16
  59. n_update = 20000 * 10
  60.  
  61. # sigma = tf.constant(1.)
  62. depth = 5  # Depth residual block for the AutoEncoder
  63. lr0 = 1e-4  # Learning rate
  64. regularizer = tf.contrib.layers.l2_regularizer(scale=0.01)  # Regularization term for all layers
  65. regularizer2 = tf.contrib.layers.l2_regularizer(scale=0.1)  # Regularization term for layer that outputs y
  66. initializer = None  # tf.initializers.he_normal()
  67. image_height = 200
  68. image_width = 200
  69.  
  70. beta = 1000
  71.  
  72. bpp_targ = 0.4
  73. S = int(bpp_targ * image_height * image_width / 2)
  74. K = 32
  75. L = 4
  76.  
  77. huffman = [1, 2, 3, 3]
  78.  
  79. # Graph definition ------------------------------------------------------------------------------------------------------------------
  80.  
  81. # Network Placeholders
  82. lr = tf.placeholder(shape=(), dtype=tf.float32)
  83. training = tf.placeholder(dtype=tf.bool, shape=(), name="isTraining")
  84. file_path = tf.placeholder(tf.string, name="path")
  85.  
  86. constant = (tf.log(tf.cast(L, tf.float32)) / tf.log(2.)) / tf.cast(image_height * image_width, tf.float32)
  87.  
  88. # Read input from pipeline
  89. with tf.device('/cpu:0'):
  90.     # image is in range [0, 1]
  91.     x = tf.reshape(tf.image.convert_image_dtype(iterator.get_next(), dtype=tf.float32), [batch_size, 200, 200, 3])
  92.     # filenames = iterator.get_next()[1]
  93.  
  94. # Encoder
  95.  
  96. # [mean, var] = tf.nn.moments(x, axes=[0, 1, 2])
  97. # mean = tf.transpose(tf.expand_dims(tf.expand_dims(tf.expand_dims(mean, -1), -1), -1), [1, 2, 3, 0])
  98. # var = tf.transpose(tf.expand_dims(tf.expand_dims(tf.expand_dims(var, -1), -1), -1), [1, 2, 3, 0])
  99. #
  100. # x_n = (x - mean) / tf.sqrt(var + 1e-10)
  101.  
  102. with tf.name_scope("Encoder"):
  103.     conv1 = tf.layers.conv2d(inputs=x,
  104.                              filters=64,
  105.                              kernel_size=[5, 5],
  106.                              strides=(2, 2),
  107.                              padding="same",
  108.                              kernel_regularizer=regularizer,
  109.                              kernel_initializer=initializer,
  110.                              name="conv1")
  111.  
  112.     conv1 = tf.layers.batch_normalization(inputs=conv1, training=training)
  113.     conv1 = tf.nn.relu(conv1)
  114.     conv2 = tf.layers.conv2d(inputs=conv1,
  115.                              filters=128,
  116.                              kernel_size=[5, 5],
  117.                              strides=(2, 2),
  118.                              padding="same",
  119.                              kernel_regularizer=regularizer,
  120.                              kernel_initializer=initializer,
  121.                              name="conv2")
  122.  
  123.     conv2 = tf.layers.batch_normalization(inputs=conv2, training=training)
  124.     conv2 = tf.nn.relu(conv2)
  125.  
  126.     E_residual_blocks = []
  127.     tmp = conv2
  128.     for i in range(depth):
  129.         tmp3 = tmp
  130.         for j in range(3):
  131.             tmp2 = tmp
  132.             E_residual_blocks.append(tf.layers.conv2d(inputs=tmp,
  133.                                                       filters=128,
  134.                                                       kernel_size=[3, 3],
  135.                                                       strides=(1, 1),
  136.                                                       padding="same",
  137.                                                       kernel_regularizer=regularizer,
  138.                                                       kernel_initializer=initializer,
  139.                                                       name="conv" + str(6 * i + 2 * j + 3)))
  140.  
  141.             E_residual_blocks[-1] = tf.layers.batch_normalization(inputs=E_residual_blocks[-1], training=training)
  142.             E_residual_blocks[-1] = tf.nn.relu(E_residual_blocks[-1])
  143.             tmp = E_residual_blocks[-1]
  144.             E_residual_blocks.append(tf.layers.conv2d(inputs=tmp,
  145.                                                       filters=128,
  146.                                                       kernel_size=[3, 3],
  147.                                                       strides=(1, 1),
  148.                                                       padding="same",
  149.                                                       kernel_regularizer=regularizer,
  150.                                                       kernel_initializer=initializer,
  151.                                                       name="conv" + str(6 * i + 2 * j + 4)))
  152.  
  153.             tmp = E_residual_blocks[-1] + tmp2
  154.         tmp = tmp3 + tmp
  155.  
  156.     tmp2 = tmp
  157.     E_residual_blocks.append(tf.layers.conv2d(inputs=tmp,
  158.                                               filters=128,
  159.                                               kernel_size=[3, 3],
  160.                                               strides=(1, 1),
  161.                                               padding="same",
  162.                                               kernel_regularizer=regularizer,
  163.                                               kernel_initializer=initializer,
  164.                                               name="conv" + str(depth * 6 + 3)))
  165.  
  166.     E_residual_blocks[-1] = tf.layers.batch_normalization(inputs=E_residual_blocks[-1], training=training)
  167.     E_residual_blocks[-1] = tf.nn.relu(E_residual_blocks[-1])
  168.     tmp = E_residual_blocks[-1]
  169.     E_residual_blocks.append(tf.layers.conv2d(inputs=tmp,
  170.                                               filters=128,
  171.                                               kernel_size=[3, 3],
  172.                                               strides=(1, 1),
  173.                                               padding="same",
  174.                                               kernel_regularizer=regularizer,
  175.                                               kernel_initializer=initializer,
  176.                                               name="conv" + str(depth * 6 + 4)))
  177.  
  178.     tmp = E_residual_blocks[-1] + tmp2 + conv2
  179.  
  180.     e_out = tf.layers.conv2d(inputs=tmp,
  181.                              filters=K,
  182.                              kernel_size=[5, 5],
  183.                              strides=(2, 2),
  184.                              padding="same",
  185.                              kernel_regularizer=regularizer,
  186.                              kernel_initializer=initializer,
  187.                              name="conv" + str(depth * 6 + 5))
  188.  
  189.     z = tf.layers.conv3d(inputs=tf.expand_dims(e_out, axis=-1),
  190.                          filters=L,
  191.                          kernel_size=[1, 1, 1],
  192.                          strides=(1, 1, 1),
  193.                          padding="same",
  194.                          kernel_regularizer=regularizer,
  195.                          kernel_initializer=initializer,
  196.                          name="conv3d")
  197.  
  198.     z_soft = tf.nn.softmax(z, axis=-1)  # probability for each value to be every symbol
  199.  
  200.     cm1 = tf.layers.conv2d(inputs=tmp,
  201.                            filters=24,
  202.                            kernel_size=[3, 3],
  203.                            strides=(1, 1),
  204.                            padding="same",
  205.                            kernel_regularizer=regularizer,
  206.                            kernel_initializer=initializer,
  207.                            name="conv_cm1")
  208.  
  209.     cm1 = tf.nn.relu(cm1)
  210.  
  211.     cm2 = tf.layers.conv2d(inputs=cm1,
  212.                            filters=24,
  213.                            kernel_size=[3, 3],
  214.                            strides=(1, 1),
  215.                            padding="same",
  216.                            kernel_regularizer=regularizer,
  217.                            kernel_initializer=initializer,
  218.                            name="conv_cm2")
  219.  
  220.     cm2 = tf.nn.relu(cm2)
  221.  
  222.     cm3 = tf.layers.conv2d(inputs=cm2,
  223.                            filters=24,
  224.                            kernel_size=[3, 3],
  225.                            strides=(1, 1),
  226.                            padding="same",
  227.                            kernel_regularizer=regularizer,
  228.                            kernel_initializer=initializer,
  229.                            name="conv_cm3")
  230.  
  231.     cm3 = cm1 + cm3
  232.  
  233.     y_out = tf.layers.conv2d(inputs=cm3,
  234.                              filters=1,
  235.                              kernel_size=[5, 5],
  236.                              strides=(2, 2),
  237.                              padding="same",
  238.                              kernel_regularizer=regularizer2,
  239.                              kernel_initializer=initializer,
  240.                              name="conv" + str(depth * 6 + 6))
  241.  
  242.     #y_out = tf.nn.tanh(y_out)
  243.  
  244.     weights_y = tf.get_default_graph().get_tensor_by_name("conv" + str(depth * 6 + 6) + "/kernel:0")
  245.  
  246.     mean_w = tf.reduce_mean(tf.abs(weights_y), [0, 1, 2, 3])
  247.  
  248.     tf.summary.scalar("y_out_weights_mean", mean_w)
  249.  
  250.     # y_out = tf.layers.batch_normalization(inputs=y_out, training=training)
  251.     # y_out = tf.nn.sigmoid(y_out)
  252.  
  253. with tf.name_scope("Mask"):
  254.  
  255.     shape = tf.shape(y_out)
  256.  
  257.     y_max = tf.reduce_max(tf.abs(y_out), axis=[1, 2])
  258.     y_mean = tf.reduce_mean(tf.abs(y_out), axis=[1, 2])
  259.  
  260.     tf.summary.scalar("Max_value_mask", tf.reduce_mean(y_max))
  261.     tf.summary.scalar("Mean_value_mask", tf.reduce_mean(y_mean))
  262.  
  263.     y = tf.exp(y_out)
  264.  
  265.     y_exp_mean = tf.reduce_mean(tf.abs(y), axis=[1, 2])
  266.  
  267.     tf.summary.scalar("Mean_value_y_exp", tf.reduce_mean(y_exp_mean))
  268.  
  269.     y_sum = tf.reshape(tf.reduce_sum(y, axis=[1, 2]), [-1, 1])
  270.     tile = tf.tile(y_sum, [1, shape[1] * shape[2]])
  271.  
  272.     y = tf.div(y, tf.reshape(tile, shape))
  273.  
  274.     tf.summary.image("prob", y, 5)
  275.  
  276.     yy = tf.transpose(tf.reshape(tf.tile(tf.reshape(y,
  277.                                                     [-1]),
  278.                                          [K]),
  279.                                  [K, tf.shape(y)[0], tf.shape(y)[1], tf.shape(y)[2]]),
  280.                       [1, 2, 3, 0])
  281.  
  282.     kk = tf.transpose(tf.reshape(tf.tile(tf.linspace(0., K - 1, K),
  283.                                          [np.prod(y.get_shape().as_list())]),
  284.                                  [tf.shape(y)[0], tf.shape(y)[1], tf.shape(y)[2], K]),
  285.                       [0, 1, 2, 3])
  286.  
  287.     m = yy * S - kk
  288.     zero = tf.zeros(tf.shape(m))
  289.     m = tf.maximum(x=m, y=zero)
  290.     zero = tf.ones(tf.shape(m))
  291.     m = tf.minimum(x=m, y=zero)
  292.  
  293.     # gradient trick
  294.     m = m + tf.stop_gradient(tf.ceil(m) - m)
  295.  
  296.     number_symbols = tf.reduce_sum(m, [1, 2, 3])
  297.     tf.summary.scalar("number_of_symbols", tf.reduce_mean(number_symbols))
  298.  
  299.     bpp = number_symbols * constant
  300.     tf.summary.scalar("bpp", tf.reduce_mean(bpp))
  301.  
  302. with tf.name_scope("Quantizer"):
  303.     z_scaled = z * 1.
  304.  
  305.     z_hat = tf.cast(tf.argmax(z_soft, axis=-1) + 1, dtype=tf.float32)
  306.  
  307.     cs = tf.cumsum(tf.ones_like(z), axis=-1)
  308.  
  309.     z_soft_scaled = tf.nn.softmax((beta * z_scaled), axis=-1)
  310.     z_tilde = tf.reduce_sum((z_soft_scaled * cs), axis=-1)
  311.  
  312.     quant_error = tf.reduce_mean(tf.abs(z_hat - z_tilde), axis=[0, 1, 2, 3])
  313.     tf.summary.scalar('Quantization_error', quant_error)
  314.  
  315.     z_differentiable = tf.stop_gradient(z_hat - z_tilde) + z_tilde
  316.  
  317.     z_masked = tf.multiply(z_differentiable, m)
  318.     #z_masked = z_differentiable
  319.  
  320.     '''
  321.  
  322.    z_compressed = tf.cast(tf.multiply(tf.cast(tf.argmax(z_soft, axis=-1) + 1, tf.float32), m), tf.int32)
  323.    frequency = tf.bincount(z_compressed)
  324.  
  325.    bits = 0
  326.    for b in range(batch_size):
  327.        frequency = tf.sort(tf.bincount(z_compressed[b], minlength=L + 1)[1:L + 1], direction='DESCENDING')
  328.        for l in range(L):
  329.            bits += frequency[l] * huffman[l]
  330.    bits = bits / batch_size
  331.  
  332.    bpp_h = bits / (image_height * image_width)
  333.  
  334.    tf.summary.scalar('Huffman_bpp', bpp_h)
  335.  
  336.    '''
  337.  
  338. # Decoder
  339. with tf.name_scope("Decoder"):
  340.     D_residual_blocks = []
  341.     D_residual_blocks.append(tf.layers.conv2d_transpose(inputs=z_masked,
  342.                                                         filters=128,
  343.                                                         kernel_size=[3, 3],
  344.                                                         strides=(2, 2),
  345.                                                         padding="same",
  346.                                                         kernel_regularizer=regularizer,
  347.                                                         kernel_initializer=initializer,
  348.                                                         name="conv" + str(depth * 6 + 7)))
  349.  
  350.     D_residual_blocks[-1] = tf.layers.batch_normalization(inputs=D_residual_blocks[-1], training=training)
  351.     D_residual_blocks[-1] = tf.nn.relu(D_residual_blocks[-1])
  352.     tmp = D_residual_blocks[-1]
  353.  
  354.     for i in range(depth):
  355.         tmp3 = tmp
  356.         for j in range(3):
  357.             tmp2 = tmp
  358.             D_residual_blocks.append(tf.layers.conv2d(inputs=tmp,
  359.                                                       filters=128,
  360.                                                       kernel_size=[3, 3],
  361.                                                       strides=(1, 1),
  362.                                                       padding="same",
  363.                                                       kernel_regularizer=regularizer,
  364.                                                       kernel_initializer=initializer,
  365.                                                       name="conv" + str(6 * i + 2 * j + depth * 6 + 8)))
  366.  
  367.             D_residual_blocks[-1] = tf.layers.batch_normalization(inputs=D_residual_blocks[-1], training=training)
  368.             D_residual_blocks[-1] = tf.nn.relu(D_residual_blocks[-1])
  369.             tmp = D_residual_blocks[-1]
  370.             D_residual_blocks.append(tf.layers.conv2d(inputs=tmp,
  371.                                                       filters=128,
  372.                                                       kernel_size=[3, 3],
  373.                                                       strides=(1, 1),
  374.                                                       padding="same",
  375.                                                       kernel_regularizer=regularizer,
  376.                                                       kernel_initializer=initializer,
  377.                                                       name="conv" + str(6 * i + 2 * j + depth * 6 + 9)))
  378.  
  379.             tmp = D_residual_blocks[-1] + tmp2
  380.         tmp = tmp3 + tmp
  381.  
  382.     tmp2 = tmp
  383.     D_residual_blocks.append(tf.layers.conv2d(inputs=tmp,
  384.                                               filters=128,
  385.                                               kernel_size=[3, 3],
  386.                                               strides=(1, 1),
  387.                                               padding="same",
  388.                                               kernel_regularizer=regularizer,
  389.                                               kernel_initializer=initializer,
  390.                                               name="conv" + str(depth * 14 + 4)))
  391.  
  392.     D_residual_blocks[-1] = tf.layers.batch_normalization(inputs=D_residual_blocks[-1], training=training)
  393.     D_residual_blocks[-1] = tf.nn.relu(D_residual_blocks[-1])
  394.     tmp = D_residual_blocks[-1]
  395.     D_residual_blocks.append(tf.layers.conv2d(inputs=tmp,
  396.                                               filters=128,
  397.                                               kernel_size=[3, 3],
  398.                                               strides=(1, 1),
  399.                                               padding="same",
  400.                                               kernel_regularizer=regularizer,
  401.                                               kernel_initializer=initializer,
  402.                                               name="conv" + str(depth * 14 + 5)))
  403.  
  404.     tmp = D_residual_blocks[-1] + tmp2 + D_residual_blocks[0]
  405.  
  406.     deconv1 = tf.layers.conv2d_transpose(inputs=tmp,
  407.                                          filters=64,
  408.                                          kernel_size=[5, 5],
  409.                                          strides=(2, 2),
  410.                                          padding="same",
  411.                                          kernel_regularizer=regularizer,
  412.                                          kernel_initializer=initializer,
  413.                                          name="deconv1")
  414.  
  415.     deconv1 = tf.layers.batch_normalization(inputs=deconv1, training=training)
  416.     deconv1 = tf.nn.relu(deconv1)
  417.     deconv2 = tf.layers.conv2d_transpose(inputs=deconv1,
  418.                                          filters=3,
  419.                                          kernel_size=[5, 5],
  420.                                          strides=(2, 2),
  421.                                          padding="same",
  422.                                          kernel_regularizer=regularizer,
  423.                                          kernel_initializer=initializer,
  424.                                          name="deconv2")
  425.  
  426.     # deconv2 = tf.nn.sigmoid(deconv2) # images must be between 0 and 1
  427.  
  428.     # Output Decoder
  429.  
  430. x_hat = tf.minimum(tf.maximum(deconv2, 0.), 1.)  # bounded ReLu
  431.  
  432. relu_err = tf.reduce_mean(tf.abs(x_hat - deconv2), axis=[0, 1, 2, 3])
  433.  
  434. tf.summary.scalar("Bounded_RELU_error", relu_err)
  435.  
  436. # Denormalize Reconstructed Image
  437.  
  438. # x_hat_norm = x_hat * tf.sqrt(var + 1e-10) + mean
  439. # x_hat_norm = tf.clip_by_value(x_hat_norm, 0, 1.0)
  440.  
  441. tf.summary.image("x", x, 5)
  442. tf.summary.image("x_hat", x_hat, 5)
  443.  
  444. # Distortion rate index
  445. # msssim_indexR = MSSSIM.tf_ms_ssim(x[:, :, :, 0:1], x_hat[:, :, :, 0:1])
  446. # msssim_indexG = MSSSIM.tf_ms_ssim(x[:, :, :, 1:2], x_hat[:, :, :, 1:2])
  447. # msssim_indexB = MSSSIM.tf_ms_ssim(x[:, :, :, 2:3], x_hat[:, :, :, 2:3])
  448. #
  449. # acc = (msssim_indexR + msssim_indexG + msssim_indexB) / 3.
  450.  
  451. acc = tf.reduce_mean(tf.image.ssim_multiscale(x, x_hat, 1.))
  452.  
  453. distortion = (1. - acc)
  454.  
  455. mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
  456.  
  457. # loss = tf.where(tf.is_nan(distortion), mse, distortion)
  458. # loss = mse
  459. loss = distortion
  460.  
  461. tf.summary.scalar('accuracy', acc * 100.)
  462. tf.summary.scalar('loss', loss)
  463.  
  464. # Optimizer Context Model
  465. optimizer = tf.train.AdamOptimizer(learning_rate=lr)
  466.  
  467. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  468. with tf.control_dependencies(update_ops):
  469.     train = optimizer.minimize(loss)
  470.  
  471. # Graph initialization --------------------------------------------------------------------------------------------------------------
  472.  
  473. training_init_op = iterator.make_initializer(training_dataset)
  474.  
  475. sess = tf.Session()
  476.  
  477. merged = tf.summary.merge_all()
  478. train_writer = tf.summary.FileWriter('log/train', sess.graph)
  479.  
  480. init1 = tf.global_variables_initializer()
  481. init2 = tf.local_variables_initializer()
  482. sess.run(init1)
  483. sess.run(init2)
  484.  
  485. saver = tf.train.Saver()
  486.  
  487. # Model Training ------------------------------------------------------------------------------------------------------------------
  488.  
  489. update = 0
  490. sess.run(training_init_op)
  491. learning_rate = lr0
  492. for update in range(n_update):
  493.  
  494.     _, summary = sess.run((train, merged), feed_dict={training: True, lr: learning_rate})
  495.  
  496.     train_writer.add_summary(summary, update)
  497.  
  498.     if update % 40000 == 39999:
  499.         learning_rate *= 0.1
  500.  
  501. try:
  502.     saver.save(sess, "model/model.ckpt")
  503.     print("model saved successfully")
  504. except Exception:
  505.     pass
  506.  
  507. #
  508. # for i in range(num_batch):
  509. #     fn, batch_img_out, batch_img = sess.run((filenames, x_hat_norm, x), feed_dict={training: True})
  510. #
  511. #     for j in range(len(fn)):
  512. #         name = "/mnt/disks/disk2/ae_out/label/" + str(fn[j])[2:-1]
  513. #         plt.imsave(name, batch_img[j])
  514. #
  515. #         name = "/mnt/disks/disk2/ae_out/in/" + str(fn[j])[2:-1]
  516. #         plt.imsave(name, batch_img_out[j])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement