Guest User

Untitled

a guest
Jun 21st, 2018
72
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.74 KB | None | 0 0
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4.  
  5. import os
  6.  
  7. import numpy as np
  8. import tensorflow as tf
  9.  
  10. tf.logging.set_verbosity(tf.logging.INFO)
  11.  
  12. # Data sets
  13. IRIS_TRAINING = os.path.join(os.path.dirname(__file__), "iris_training.csv")
  14. IRIS_TEST = os.path.join(os.path.dirname(__file__), "iris_test.csv")
  15.  
  16. def main(unused_argv):
  17. # Load datasets.
  18. training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
  19. filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32)
  20. test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
  21. filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32)
  22.  
  23. # Specify that all features have real-value data
  24. feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
  25.  
  26. # Build 3 layer DNN with 10, 20, 10 units respectively.
  27. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
  28. hidden_units=[10, 20, 10],
  29. n_classes=3,
  30. model_dir="/tmp/iris_model",
  31. config= tf.contrib.learn.RunConfig(save_checkpoints_secs=1))
  32. validation_metrics = {
  33. "accuracy":
  34. tf.contrib.learn.MetricSpec(
  35. metric_fn=tf.contrib.metrics.streaming_accuracy,
  36. prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
  37. CLASSES),
  38. "precision":
  39. tf.contrib.learn.MetricSpec(
  40. metric_fn=tf.contrib.metrics.streaming_precision,
  41. prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
  42. CLASSES),
  43. "recall":
  44. tf.contrib.learn.MetricSpec(
  45. metric_fn=tf.contrib.metrics.streaming_recall,
  46. prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
  47. CLASSES)
  48. }
  49.  
  50.  
  51.  
  52. validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
  53. test_set.data,
  54. test_set.target,
  55. every_n_steps=50,
  56. metrics = validation_metrics)
  57.  
  58. # Fit model.
  59. classifier.fit(x=training_set.data,
  60. y=training_set.target,
  61. steps=2000,
  62. monitors=[validation_monitor])
  63.  
  64. # Evaluate accuracy.
  65. accuracy_score = classifier.evaluate(x=test_set.data,
  66. y=test_set.target)["accuracy"]
  67. print('Accuracy: {0:f}'.format(accuracy_score))
  68.  
  69. # Classify two new flower samples.
  70. new_samples = np.array(
  71. [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
  72. y = list(classifier.predict(new_samples, as_iterable=True))
  73. print('Predictions: {}'.format(str(y)))
  74.  
  75. if __name__ == "__main__":
  76. tf.app.run()
Add Comment
Please, Sign In to add comment