Guest User

Untitled

a guest
Feb 21st, 2018
92
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 1.99 KB | None | 0 0
  1. import os
  2. os.environ["TF_CPP_MIN_LOG_LEVEL"]="2"
  3.  
  4. import tensorflow as tf
  5. import numpy as np
  6.  
  7. ## TODO
  8. ## - Move learning rate
  9. ## - Save model, restore model
  10. class NNModel():
  11. def __init__(self, layers, num_classes):
  12. print("In NNModel constructor")
  13. self.layers = layers
  14. self.num_classes = num_classes
  15. self.learning_rate = 0.1
  16. self.model = tf.estimator.Estimator(self.build)
  17.  
  18. def build_layers(self, dictionary):
  19. tmp = dictionary["x"]
  20. for layer in self.layers:
  21. tmp = tf.layers.dense(tmp, layer)
  22. tmp = tf.layers.dense(tmp, self.num_classes)
  23. return tmp
  24.  
  25. def train(self, input_fn, num_steps):
  26. self.model.train(input_fn, steps = num_steps)
  27.  
  28. def evaluate(self, input_fn):
  29. return self.model.evaluate(input_fn)
  30.  
  31. def predict(self, input_fn):
  32. return self.model.predict(input_fn)
  33.  
  34. def build(self, features, labels, mode):
  35. logits = self.build_layers(features)
  36.  
  37. pred_classes = tf.argmax(logits, axis=1)
  38. pred_probas = tf.nn.softmax(logits)
  39.  
  40. if mode == tf.estimator.ModeKeys.PREDICT:
  41. return tf.estimator.EstimatorSpec(mode, predictions=pred_classes)
  42.  
  43. loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
  44. logits=logits, labels=tf.cast(labels, dtype=tf.int32)))
  45. optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate)
  46. train_op = optimizer.minimize(loss_op, global_step=tf.train.get_global_step())
  47.  
  48. # Evaluate the accuracy of the model
  49. acc_op = tf.metrics.accuracy(labels=labels, predictions=pred_classes)
  50.  
  51. # TF Estimators requires to return a EstimatorSpec, that specify
  52. # the different ops for training, evaluating, ...
  53. estim_specs = tf.estimator.EstimatorSpec(
  54. mode=mode,
  55. predictions=pred_classes,
  56. loss=loss_op,
  57. train_op=train_op,
  58. eval_metric_ops={'accuracy': acc_op})
  59. return estim_specs
Add Comment
Please sign in to add a comment.