# NOTE: TensorFlow 1.x-era code (graph mode, sessions, tf.contrib.data,
# and the tutorial MNIST loader).
from __future__ import division, print_function
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import data

### the function of interest

def per_example_gradients(loss, variables, grad_ys=None):
    """Given a vector `loss` of per-example losses, return gradients that keep
    a leading batch axis: one gradient slice per example, per variable."""
    def rewrite_grad(grad):
        # dispatch on the op that produced the aggregated gradient
        for check, handler in handlers:
            if check(grad.op):
                return handler(grad.op)
        else:
            # this `else` belongs to the `for` loop: it fires only when no
            # registered handler matched
            raise ValueError("Can't handle op: {}".format(grad.op))
    grads = tf.gradients(loss, variables, grad_ys=grad_ys)
    return [rewrite_grad(grad) for grad in grads]

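# For comparison, a naive reference implementation (not part of the original
# paste): build one gradient subgraph per example and stack the results. It
# assumes `loss` is a per-example vector with a statically known batch size
# (e.g. after loss.set_shape([batch_size])), and it is far slower than the
# graph-rewriting approach above.
def per_example_gradients_naive(loss, variables):
    per_example = [tf.gradients(l, variables) for l in tf.unstack(loss)]
    return [tf.stack(grads) for grads in zip(*per_example)]
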
# we have to do some local graph inspection and dispatch based on op type

handlers = []
def register_handler(check, handler):
    handlers.append((check, handler))

def handle_matmul(op):
    # the gradient of W in tf.matmul(inputs, W) is itself a MatMul of the
    # layer inputs with the output gradients; the per-example gradient is the
    # per-example outer product of those two factors
    inputs, output_grads = op.inputs
    return inputs[..., None] * output_grads[:, None, :]
register_handler(lambda op: op.type == 'MatMul', handle_matmul)

def handle_bias_add(op):
    # the gradient of a broadcast bias add is a Sum over the batch axis
    # followed by a Reshape; the per-example gradient is just the Sum's input
    return op.inputs[0].op.inputs[0]
register_handler(lambda op: op.type == 'Reshape' and op.inputs[0].op.type == 'Sum',
                 handle_bias_add)

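# A quick consistency check (hypothetical helper, not in the original paste):
# summing a per-example gradient over its leading batch axis should recover
# the ordinary aggregated gradient, up to floating-point error.
def max_abs_error_vs_aggregate(loss, variable, per_example_grad):
    aggregate = tf.gradients(loss, [variable])[0]
    summed = tf.reduce_sum(per_example_grad, axis=0)
    return tf.reduce_max(tf.abs(summed - aggregate))
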
### example

def predict(params, inputs):
    # a small tanh MLP; the final layer's pre-activations are the logits
    for W, b in params:
        outputs = tf.matmul(inputs, W) + b
        inputs = tf.tanh(outputs)
    return outputs


def make_loss(params, batch):
    # returns a *vector* of per-example losses, which is what
    # per_example_gradients expects
    inputs, labels = batch
    logits = predict(params, inputs)
    return tf.nn.softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)


def rand_params(layer_sizes):
    def randn(*shape):
        return tf.Variable(tf.random_normal(shape, stddev=1e-2))
    return [(randn(m, n), randn(n))
            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

if __name__ == '__main__':
    # setup
    layer_sizes = [28*28, 100, 100, 10]
    batch_size = 32

    # load dataset
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    train_data = data.Dataset.from_tensor_slices(
        (mnist.train.images, mnist.train.labels))
    batch = train_data.batch(batch_size).make_one_shot_iterator()

    # initialize parameters and set up loss function
    params = rand_params(layer_sizes)
    loss = make_loss(params, batch.get_next())

    # compute per-example gradients
    variables = [elt for pair in params for elt in pair]
    grads = per_example_gradients(loss, variables)

    # run a session to check the shapes
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        print([np.shape(g) for g in sess.run(grads)])
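
# Expected output (if the rewrite behaves as intended): each gradient gains a
# leading batch axis, so with the settings above the printed shapes should be
#   [(32, 784, 100), (32, 100), (32, 100, 100), (32, 100), (32, 100, 10), (32, 10)]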