Advertisement
Guest User

Untitled

a guest
Apr 22nd, 2019
257
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.33 KB | None | 0 0
  1. pickle.dump(predictions, open("predictions.p", "wb"))
  2. pickle.dump(history, open("history.p", "wb"))
  3. tf.train.write_graph(sess.graph_def, '.', './checkpoint/har.pbtxt')
  4. saver.save(sess, save_path = "./checkpoint/har.ckpt")
  5. sess.close()
  6.  
  7. history = pickle.load(open("history.p", "rb"))
  8. predictions = pickle.load(open("predictions.p", "rb"))
  9.  
  10.  
  11. plt.figure(figsize=(12, 8))
  12. plt.plot(np.array(history['train_loss']), "r--", label="Train loss")
  13. plt.plot(np.array(history['train_acc']), "g--", label="Train accuracy")
  14. plt.plot(np.array(history['test_loss']), "r-", label="Test loss")
  15. plt.plot(np.array(history['test_acc']), "g-", label="Test accuracy")
  16. plt.title("Training session's progress over iterations")
  17. plt.legend(loc='upper right', shadow=True)
  18. plt.ylabel('Training Progress (Loss or Accuracy values)')
  19. plt.xlabel('Training Epoch')
  20. plt.ylim(0)
  21. plt.show()
  22.  
  23.  
  24. LABELS = ['Walk', 'ObstacleWalk', 'Kick', 'Lift']
  25. max_test = np.argmax(y_test, axis=1)
  26. max_predictions = np.argmax(predictions, axis=1)
  27. confusion_matrix = metrics.confusion_matrix(max_test, max_predictions)
  28. plt.figure(figsize=(16, 14))
  29. sns.heatmap(confusion_matrix, xticklabels=LABELS, yticklabels=LABELS, annot=True, fmt="d");
  30. plt.title("Confusion matrix")
  31. plt.ylabel('True label')
  32. plt.xlabel('Predicted label')
  33. plt.show()
  34.  
  35. from tensorflow.python.tools import freeze_graph
  36.  
  37. MODEL_NAME = 'har'
  38.  
  39. input_graph_path = 'checkpoint/' + MODEL_NAME+'.pbtxt'
  40. checkpoint_path = './checkpoint/' +MODEL_NAME+'.ckpt'
  41. restore_op_name = "save/restore_all"
  42. filename_tensor_name = "save/Const:0"
  43. output_frozen_graph_name = 'frozen_'+MODEL_NAME+'.pb'
  44.  
  45. freeze_graph.freeze_graph(input_graph_path, input_saver="",
  46. input_binary=False, input_checkpoint=checkpoint_path,
  47. output_node_names="y_", restore_op_name="save/restore_all",
  48. filename_tensor_name="save/Const:0",
  49. output_graph=output_frozen_graph_name, clear_devices=True, initializer_nodes="")
  50.  
  51. print("Final Accuracy: ", acc_final)
  52. print("Final Loss: ", loss_final)
  53. print("Precision: ", metrics.precision_score(max_test, max_predictions))
  54. print("Recall: ", metrics.recall_score(max_test, max_predictions))
  55. print("F1 Score: ", metrics.f1_score(max_test, max_predictions))
  56.  
  57. frp, tpr, threshold = metrics.roc_curve(max_test, max_predictions)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement