Guest User

Untitled

a guest
Jul 21st, 2018
88
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.18 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": null,
  6. "metadata": {
  7. "collapsed": false
  8. },
  9. "outputs": [],
  10. "source": [
  11. "%matplotlib inline"
  12. ]
  13. },
  14. {
  15. "cell_type": "markdown",
  16. "metadata": {},
  17. "source": [
  18. "\n# Confusion matrix\n\n\nExample of confusion matrix usage to evaluate the quality\nof the output of a classifier on the iris data set. The\ndiagonal elements represent the number of points for which\nthe predicted label is equal to the true label, while\noff-diagonal elements are those that are mislabeled by the\nclassifier. The higher the diagonal values of the confusion\nmatrix, the better, indicating many correct predictions.\n\nThe figures show the confusion matrix with and without\nnormalization by class support size (number of elements\nin each class). This kind of normalization can be\nuseful in case of class imbalance, as it gives a clearer\nvisual interpretation of which class is being misclassified.\n\nHere the results are not as good as they could be because the\nregularization parameter C was deliberately set too low (C=0.01).\nIn real-life applications this parameter is usually chosen\nusing a grid search (e.g. `GridSearchCV`).\n\n\n"
  19. ]
  20. },
  21. {
  22. "cell_type": "code",
  23. "execution_count": null,
  24. "metadata": {
  25. "collapsed": false
  26. },
  27. "outputs": [],
  28. "source": [
  29. "print(__doc__)\n\nimport itertools\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nfrom sklearn import svm, datasets\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.metrics import confusion_matrix\n\n# import some data to play with\niris = datasets.load_iris()\nX = iris.data\ny = iris.target\nclass_names = iris.target_names\n\n# Split the data into a training set and a test set\nX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n\n# Run classifier, using a model that is too regularized (C too low) to see\n# the impact on the results\nclassifier = svm.SVC(kernel='linear', C=0.01)\ny_pred = classifier.fit(X_train, y_train).predict(X_test)\n\n\ndef plot_confusion_matrix(cm, classes,\n normalize=False,\n title='Confusion matrix',\n cmap=plt.cm.Blues):\n \"\"\"\n This function prints and plots the confusion matrix.\n Normalization can be applied by setting `normalize=True`.\n \"\"\"\n if normalize:\n cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n print(\"Normalized confusion matrix\")\n else:\n print('Confusion matrix, without normalization')\n\n print(cm)\n\n plt.imshow(cm, interpolation='nearest', cmap=cmap)\n plt.title(title)\n plt.colorbar()\n tick_marks = np.arange(len(classes))\n plt.xticks(tick_marks, classes, rotation=45)\n plt.yticks(tick_marks, classes)\n\n fmt = '.2f' if normalize else 'd'\n thresh = cm.max() / 2.\n for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n plt.text(j, i, format(cm[i, j], fmt),\n horizontalalignment=\"center\",\n color=\"white\" if cm[i, j] > thresh else \"black\")\n\n plt.tight_layout()\n plt.ylabel('True label')\n plt.xlabel('Predicted label')\n\n# Compute confusion matrix\ncnf_matrix = confusion_matrix(y_test, y_pred)\nnp.set_printoptions(precision=2)\n\n# Plot non-normalized confusion matrix\nplt.figure()\nplot_confusion_matrix(cnf_matrix, classes=class_names,\n title='Confusion matrix, without normalization')\n\n# Plot normalized confusion matrix\nplt.figure()\nplot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,\n title='Normalized confusion matrix')\n\nplt.show()"
  30. ]
  31. }
  32. ],
  33. "metadata": {
  34. "kernelspec": {
  35. "display_name": "Python 3",
  36. "language": "python",
  37. "name": "python3"
  38. },
  39. "language_info": {
  40. "codemirror_mode": {
  41. "name": "ipython",
  42. "version": 3
  43. },
  44. "file_extension": ".py",
  45. "mimetype": "text/x-python",
  46. "name": "python",
  47. "nbconvert_exporter": "python",
  48. "pygments_lexer": "ipython3",
  49. "version": "3.6.6"
  50. }
  51. },
  52. "nbformat": 4,
  53. "nbformat_minor": 0
  54. }
Add Comment
Please, Sign In to add comment