Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Persist the evaluation artifacts and the trained TensorFlow graph.
# Context managers guarantee the pickle files are flushed and closed even if
# dump() raises — the original `pickle.dump(obj, open(...))` leaked the handles.
with open("predictions.p", "wb") as f:
    pickle.dump(predictions, f)
with open("history.p", "wb") as f:
    pickle.dump(history, f)

# Write the graph definition as a text protobuf plus the checkpointed weights
# (both are consumed by the freeze_graph step later), then release the session.
tf.train.write_graph(sess.graph_def, '.', './checkpoint/har.pbtxt')
saver.save(sess, save_path="./checkpoint/har.ckpt")
sess.close()
# Reload the persisted training history and predictions (round-trips the
# files written above). `with` closes the handles deterministically instead of
# leaking them like the original `pickle.load(open(...))` calls.
with open("history.p", "rb") as f:
    history = pickle.load(f)
with open("predictions.p", "rb") as f:
    predictions = pickle.load(f)

# Plot train/test loss and accuracy over training iterations on one figure.
# Dashed lines = train curves, solid lines = test curves; red = loss, green = accuracy.
plt.figure(figsize=(12, 8))
plt.plot(np.array(history['train_loss']), "r--", label="Train loss")
plt.plot(np.array(history['train_acc']), "g--", label="Train accuracy")
plt.plot(np.array(history['test_loss']), "r-", label="Test loss")
plt.plot(np.array(history['test_acc']), "g-", label="Test accuracy")
plt.title("Training session's progress over iterations")
plt.legend(loc='upper right', shadow=True)
plt.ylabel('Training Progress (Loss or Accuracy values)')
plt.xlabel('Training Epoch')
plt.ylim(0)  # anchor the y-axis at zero so loss/accuracy scales are comparable
plt.show()
# Activity class names; order must correspond to the one-hot column order
# used in y_test/predictions.
LABELS = ['Walk', 'ObstacleWalk', 'Kick', 'Lift']

# Collapse one-hot targets and model outputs to integer class indices.
max_test = np.argmax(y_test, axis=1)
max_predictions = np.argmax(predictions, axis=1)

# Tabulate true vs. predicted classes and render the result as an annotated
# integer heatmap with the activity names on both axes.
confusion_matrix = metrics.confusion_matrix(max_test, max_predictions)
plt.figure(figsize=(16, 14))
sns.heatmap(
    confusion_matrix,
    xticklabels=LABELS,
    yticklabels=LABELS,
    annot=True,
    fmt="d",
)
plt.title("Confusion matrix")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
# Freeze the trained graph: fold the checkpointed weights into the graph
# definition so a single self-contained .pb file can be deployed.
from tensorflow.python.tools import freeze_graph

MODEL_NAME = 'har'
input_graph_path = 'checkpoint/' + MODEL_NAME + '.pbtxt'
checkpoint_path = './checkpoint/' + MODEL_NAME + '.ckpt'
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_frozen_graph_name = 'frozen_' + MODEL_NAME + '.pb'

freeze_graph.freeze_graph(
    input_graph_path,
    input_saver="",
    input_binary=False,  # the graph was written as a text .pbtxt, not binary
    input_checkpoint=checkpoint_path,
    output_node_names="y_",  # keep only the subgraph feeding the prediction tensor
    # Pass the variables defined above instead of re-duplicating the literals
    # (the original defined them and then ignored them in the call).
    restore_op_name=restore_op_name,
    filename_tensor_name=filename_tensor_name,
    output_graph=output_frozen_graph_name,
    clear_devices=True,  # strip device placements so the frozen graph is portable
    initializer_nodes="",
)
# Report final evaluation metrics.
# precision/recall/F1 default to average='binary', which raises ValueError on
# this 4-class problem — an explicit multiclass average is required.
# 'weighted' averages per-class scores weighted by class support.
print("Final Accuracy: ", acc_final)
print("Final Loss: ", loss_final)
print("Precision: ", metrics.precision_score(max_test, max_predictions, average='weighted'))
print("Recall: ", metrics.recall_score(max_test, max_predictions, average='weighted'))
print("F1 Score: ", metrics.f1_score(max_test, max_predictions, average='weighted'))

# NOTE(review): metrics.roc_curve supports binary targets only; with four
# classes this call raises "multiclass format is not supported" — it should be
# computed per class (one-vs-rest) on probability scores, not on argmax labels.
# 'frp' is presumably a typo for 'fpr' (false positive rate); the name is kept
# in case later (unseen) code references it — confirm before renaming.
frp, tpr, threshold = metrics.roc_curve(max_test, max_predictions)
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement