# NOTE: this paste targets TensorFlow 1.x APIs (tf.contrib.data, tf.Session,
# and the tutorial MNIST loader); it will not run unmodified on TF 2.x.
from __future__ import division, print_function
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import data
### the function of interest
def per_example_gradients(loss, variables, grad_ys=None):
    # returns one tensor per variable with a leading batch dimension,
    # i.e. shape [batch] + variable.shape
    def rewrite_grad(grad):
        for check, handler in handlers:
            if check(grad.op):
                return handler(grad.op)
        raise ValueError("Can't handle op: {}".format(grad.op))

    grads = tf.gradients(loss, variables, grad_ys=grad_ys)
    return [rewrite_grad(grad) for grad in grads]
# we have to do some local graph inspection and dispatch based on op
handlers = []

def register_handler(check, handler):
    handlers.append((check, handler))

def handle_matmul(op):
    # the summed weight gradient is MatMul(inputs, output_grads,
    # transpose_a=True); the per-example gradient is the corresponding
    # batch of outer products, shape [batch, m, n]
    inputs, output_grads = op.inputs
    return inputs[..., None] * output_grads[:, None, :]
register_handler(lambda op: op.type == 'MatMul', handle_matmul)

def handle_bias_add(op):
    # the summed bias gradient shows up as Reshape(Sum(output_grads));
    # the per-example gradient is just the output grads before the Sum
    return op.inputs[0].op.inputs[0]
register_handler(lambda op: op.type == 'Reshape' and op.inputs[0].op.type == 'Sum',
                 handle_bias_add)
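
### A quick diagnostic (a sketch, not part of the rewrite itself): to see
### which op types the dispatch has to cover for a given model, inspect the
### raw gradient graph directly once `loss` and `variables` exist, e.g.
###     grads = tf.gradients(loss, variables)
###     print([g.op.type for g in grads])
### For the MLP below this prints 'MatMul' for each weight matrix and
### 'Reshape' for each bias, which is exactly what the two handlers match.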
### example
def predict(params, inputs):
    # simple tanh MLP; returns the last layer's pre-activation as logits
    for W, b in params:
        outputs = tf.matmul(inputs, W) + b
        inputs = tf.tanh(outputs)
    return outputs

def make_loss(params, batch):
    inputs, labels = batch
    logits = predict(params, inputs)
    # per-example losses of shape [batch]; tf.gradients sums them, and the
    # rewrite above recovers the unsummed, per-example pieces
    return tf.nn.softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)

def rand_params(layer_sizes):
    def randn(*shape):
        return tf.Variable(tf.random_normal(shape, stddev=1e-2))
    return [(randn(m, n), randn(n))
            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
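
### Worked example of the shapes involved: for layer_sizes = [28*28, 100,
### 100, 10] below, rand_params yields weight/bias pairs of shapes
### (784, 100)/(100,), (100, 100)/(100,), and (100, 10)/(10,), so with
### batch_size = 32 the per-example gradients printed at the end should be
### (32, 784, 100), (32, 100), (32, 100, 100), (32, 100), (32, 100, 10),
### (32, 10).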
if __name__ == '__main__':
    # setup
    layer_sizes = [28*28, 100, 100, 10]
    batch_size = 32

    # load dataset
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    train_data = data.Dataset.from_tensor_slices(
        (mnist.train.images, mnist.train.labels))
    batch = train_data.batch(batch_size).make_one_shot_iterator()

    # initialize parameters and set up loss function
    params = rand_params(layer_sizes)
    loss = make_loss(params, batch.get_next())

    # compute per-example gradients
    variables = [elt for pair in params for elt in pair]
    grads = per_example_gradients(loss, variables)

    # run a session to check the shapes
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        # map() is lazy on Python 3, so materialize the shapes as a list
        print([np.shape(g) for g in sess.run(grads)])
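
### Sanity check -- a minimal sketch, not part of the original paste. It
### rebuilds the loss on placeholders (`xs`/`ys` and the function name are
### hypothetical) and compares the rewritten gradients against naive
### single-example gradients: on a batch of size 1, the gradient of the
### summed loss *is* that example's gradient.
def check_per_example_gradients(params, variables, images, labels):
    xs = tf.placeholder(tf.float32, [None, 28*28])
    ys = tf.placeholder(tf.float32, [None, 10])
    loss = make_loss(params, (xs, ys))
    fast = per_example_gradients(loss, variables)
    slow = tf.gradients(tf.reduce_sum(loss), variables)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # rewritten gradients for the whole batch at once
        fast_vals = sess.run(fast, {xs: images, ys: labels})
        for i in range(len(images)):
            # naive gradients, one example at a time
            slow_vals = sess.run(
                slow, {xs: images[i:i+1], ys: labels[i:i+1]})
            for f, s in zip(fast_vals, slow_vals):
                assert np.allclose(f[i], s, atol=1e-5)
### e.g., callable from the __main__ block above:
###     check_per_example_gradients(params, variables,
###                                 mnist.train.images[:4],
###                                 mnist.train.labels[:4])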