Guest User

Adapted from model.py

a guest
Oct 24th, 2016
155
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.29 KB | None | 0 0
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. #     http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Builds the MNIST network.
  16.  
  17. Implements the factory method create_model(), which creates a Model
  18. providing MNIST-specific implementations of build_train_graph(),
  19. build_eval_graph(), build_prediction_graph() and format_metric_values().
  20. """
  21.  
  22. import argparse
  23. import json
  24. import logging
  25.  
  26. import tensorflow as tf
  27. from tensorflow.contrib import layers
  28. from tensorflow.contrib.metrics.python.ops import metric_ops
  29. import util
  30. from util import override_if_not_in_args
  31. import numpy as np
  32.  
# NOTE(review): this session is created as a side effect of importing the
# module and is not referenced by any definition visible in this file --
# confirm whether a caller relies on it before removing.
sess = tf.Session()
  34.  
  35. # Hyper-parameters
  36. HIDDEN1 = 128  # Number of units in hidden layer 1.
  37. HIDDEN2 = 32  # Number of units in hidden layer 2.
  38.  
  39. # The MNIST dataset has 10 classes, representing the digits 0 through 9.
  40. NUM_CLASSES = 10
  41.  
  42. # The MNIST images are always 28x28 pixels.
  43. IMAGE_SIZE = 28
  44. IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
  45.  
  46.  
  47. def create_model():
  48.   """Factory method that creates model to be used by generic task.py."""
  49.   parser = argparse.ArgumentParser()
  50.   parser.add_argument('--learning_rate', type=float, default=0.01)
  51.   args, task_args = parser.parse_known_args()
  52.  
  53.   override_if_not_in_args('--max_steps', '5000', task_args)
  54.   override_if_not_in_args('--batch_size', '100', task_args)
  55.   override_if_not_in_args('--eval_set_size', '10000', task_args)
  56.   override_if_not_in_args('--eval_interval_secs', '1', task_args)
  57.   override_if_not_in_args('--log_interval_secs', '1', task_args)
  58.   override_if_not_in_args('--min_train_eval_rate', '1', task_args)
  59.  
  60.   return Model(args.learning_rate, HIDDEN1, HIDDEN2), task_args
  61.  
  62.  
  63. class GraphReferences(object):
  64.   """Holder of base tensors used for training model using common task."""
  65.  
  66.   def __init__(self):
  67.     self.examples = None
  68.     self.train = None
  69.     self.global_step = None
  70.     self.metric_updates = []
  71.     self.metric_values = []
  72.     self.keys = None
  73.     self.predictions = []
  74.  
  75.  
  76. class Model(object):
  77.   """TensorFlow model for the MNIST problem."""
  78.  
  79.   def __init__(self, learning_rate, hidden1, hidden2):
  80.     self.learning_rate = learning_rate
  81.     self.hidden1 = hidden1
  82.     self.hidden2 = hidden2
  83.  
  84.   def build_graph(self, data_paths, batch_size, is_training):
  85.     """Builds generic graph for training or eval."""
  86.     tensors = GraphReferences()
  87.  
  88.     _, tensors.examples = util.read_examples(
  89.         data_paths,
  90.         batch_size,
  91.         shuffle=is_training,
  92.         num_epochs=None if is_training else 2)
  93.  
  94.     parsed = parse_examples(tensors.examples)
  95.  
  96.     # Build a Graph that computes predictions from the inference model.
  97.     features = tf.string_to_number(parsed['features'], tf.float32)
  98.     #features = tf.to_float(parsed['features'])
  99.     logits = inference(features, self.hidden1, self.hidden2)
  100.  
  101.     # Add to the Graph the Ops for loss calculation.
  102.     loss_value = loss(logits, parsed['labels'])
  103.  
  104.     # Add to the Graph the Ops for accuracy calculation.
  105.     accuracy_value = evaluation(logits, parsed['labels'])
  106.  
  107.     # Add to the Graph the Ops that calculate and apply gradients.
  108.     if is_training:
  109.       tensors.train, tensors.global_step = training(loss_value,
  110.                                                     self.learning_rate)
  111.     else:
  112.       tensors.global_step = tf.Variable(0, name='global_step', trainable=False)
  113.  
  114.     # Add streaming means.
  115.     loss_op, loss_update = metric_ops.streaming_mean(loss_value)
  116.     accuracy_op, accuracy_update = metric_ops.streaming_mean(accuracy_value)
  117.  
  118.     tf.scalar_summary('accuracy', accuracy_op)
  119.     tf.scalar_summary('loss', loss_op)
  120.  
  121.     tensors.metric_updates = [loss_update, accuracy_update]
  122.     tensors.metric_values = [loss_op, accuracy_op]
  123.     return tensors
  124.  
  125.   def build_train_graph(self, data_paths, batch_size):
  126.     return self.build_graph(data_paths, batch_size, is_training=True)
  127.  
  128.   def build_eval_graph(self, data_paths, batch_size):
  129.     return self.build_graph(data_paths, batch_size, is_training=False)
  130.  
  131.   def build_prediction_graph(self, export_dir):
  132.     """Builds prediction graph and registers appropriate endpoints."""
  133.     logging.info('Exporting prediction graph to %s', export_dir)
  134.     examples = tf.placeholder(tf.string, shape=(None,))
  135.     features = {
  136.         'image': tf.FixedLenFeature(
  137.             shape=[IMAGE_PIXELS], dtype=tf.float32),
  138.         'key': tf.FixedLenFeature(
  139.             shape=[], dtype=tf.string),
  140.     }
  141.  
  142.     parsed = tf.parse_example(examples, features)
  143.     images = parsed['image']
  144.     keys = parsed['key']
  145.  
  146.     # Build a Graph that computes predictions from the inference model.
  147.     logits = inference(images, self.hidden1, self.hidden2)
  148.     softmax = tf.nn.softmax(logits)
  149.     prediction = tf.argmax(softmax, 1)
  150.  
  151.     # Mark the inputs and the outputs
  152.     # Marking the input tensor with an alias with suffix _bytes. This is to
  153.     # indicate that this tensor value is raw bytes and will be base64 encoded
  154.     # over HTTP.
  155.     # Note that any output tensor marked with an alias with suffix _bytes, shall
  156.     # be base64 encoded in the HTTP response. To get the binary value, it
  157.     # should be base64 decoded.
  158.     tf.add_to_collection('inputs',
  159.                          json.dumps({'examples_bytes': examples.name}))
  160.     tf.add_to_collection('outputs', json.dumps({
  161.         'key': keys.name,
  162.         'prediction': prediction.name,
  163.         'scores': softmax.name
  164.     }))
  165.  
  166.   def format_metric_values(self, metric_values):
  167.     """Formats metric values - used for logging purpose."""
  168.     return 'loss: %.3f, accuracy: %.3f' % (metric_values[0], metric_values[1])
  169.  
  170.   def format_prediction_values(self, prediction):
  171.     """Formats prediction values - used for writing batch predictions as csv."""
  172.     return '%.3f' % (prediction[0])
  173.  
  174.  
  175. def parse_examples(examples):
  176.   feature_map = {
  177.       'labels': tf.FixedLenFeature(
  178.           shape=[], dtype=tf.int64, default_value=[-1]),
  179.       'features': tf.FixedLenFeature(
  180.           shape=[], dtype=tf.string),
  181.   }
  182.   return tf.parse_example(examples, features=feature_map)
  183.  
  184.  
  185. def inference(images, hidden1_units, hidden2_units):
  186.   """Build the MNIST model up to where it may be used for inference.
  187.  
  188.  Args:
  189.    images: Images placeholder, from inputs().
  190.    hidden1_units: Size of the first hidden layer.
  191.    hidden2_units: Size of the second hidden layer.
  192.  Returns:
  193.    softmax_linear: Output tensor with the computed logits.
  194.  """
  195.   hidden1 = layers.fully_connected(images, hidden1_units)
  196.   hidden2 = layers.fully_connected(hidden1, hidden2_units)
  197.   return layers.fully_connected(hidden2, NUM_CLASSES)
  198.  
  199.  
  200. def loss(logits, labels):
  201.   """Calculates the loss from the logits and the labels.
  202.  
  203.  Args:
  204.    logits: Logits tensor, float - [batch_size, NUM_CLASSES].
  205.    labels: Labels tensor, int32 - [batch_size].
  206.  Returns:
  207.    loss: Loss tensor of type float.
  208.  """
  209.   labels = tf.to_int64(labels)
  210.   cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
  211.       logits, labels, name='xentropy')
  212.   return tf.reduce_mean(cross_entropy, name='xentropy_mean')
  213.  
  214.  
  215. def training(loss_op, learning_rate):
  216.   """Sets up the training Ops.
  217.  
  218.  Creates a summarizer to track the loss over time in TensorBoard.
  219.  Creates an optimizer and applies the gradients to all trainable variables.
  220.  The Op returned by this function is what must be passed to the
  221.  `sess.run()` call to cause the model to train.
  222.  Args:
  223.    loss_op: Loss tensor, from loss().
  224.    learning_rate: The learning rate to use for gradient descent.
  225.  Returns:
  226.    A pair consisting of the Op for training and the global step.
  227.  """
  228.   # Create the gradient descent optimizer with the given learning rate.
  229.   optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  230.   # Create a variable to track the global step.
  231.   global_step = tf.Variable(0, name='global_step', trainable=False)
  232.   # Use the optimizer to apply the gradients that minimize the loss
  233.   # (and also increment the global step counter) as a single training step.
  234.   train_op = optimizer.minimize(loss_op, global_step=global_step)
  235.   return train_op, global_step
  236.  
  237.  
  238. def evaluation(logits, labels):
  239.   """Evaluate the quality of the logits at predicting the label.
  240.  
  241.  Args:
  242.    logits: Logits tensor, float - [batch_size, NUM_CLASSES].
  243.    labels: Labels tensor, int32 - [batch_size], with values in the
  244.      range [0, NUM_CLASSES).
  245.  Returns:
  246.    A scalar float tensor with the ratio of examples (out of batch_size)
  247.    that were predicted correctly.
  248.  """
  249.   # For a classifier model, we can use the in_top_k Op.
  250.   # It returns a bool tensor with shape [batch_size] that is true for
  251.   # the examples where the label is in the top k (here k=1)
  252.   # of all logits for that example.
  253.   correct = tf.nn.in_top_k(logits, labels, 1)
  254.   return tf.reduce_mean(tf.cast(correct, tf.float32))
Add Comment
Please, Sign In to add comment