Advertisement
wagner-cipriano

plot a pretty confusion matrix with seaborn py like matlab

Jul 3rd, 2018
280
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.47 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. """
  3. plot a pretty confusion matrix with seaborn
  4. works in python 2 and 3
  5.  
  6. Created on Mon Jun 25 14:17:37 2018
  7. @author: Wagner Cipriano - wagnerbhbr-gmail
  8. References:
  9.  https://www.mathworks.com/help/nnet/ref/plotconfusion.html
  10.    https://www.mathworks.com/help/examples/nnet/win64/PlotConfusionMatrixUsingCategoricalLabelsExample_02.png
  11.  https://stackoverflow.com/questions/28200786/how-to-plot-scikit-learn-classification-report
  12.  https://stackoverflow.com/questions/5821125/how-to-plot-confusion-matrix-with-string-axis-rather-than-integer-in-python
  13.  https://www.programcreek.com/python/example/96197/seaborn.heatmap
  14.  https://stackoverflow.com/questions/19233771/sklearn-plot-confusion-matrix-with-labels/31720054
  15. """
  16.  
  17. #imports
  18. from pandas import DataFrame
  19. import numpy as np
  20. import matplotlib.pyplot as plt
  21. import matplotlib.font_manager as fm
  22. from matplotlib.collections import QuadMesh
  23. import seaborn as sn
  24.  
  25.  
  26. def get_new_fig(fn, figsize=[9,9]):
  27.     ## Init graphics
  28.     fig1 = plt.figure(fn, figsize)
  29.     ax1 = fig1.gca()   #Get Current Axis
  30.     ax1.cla() # clear existing plot
  31.     return fig1, ax1
  32. #
  33.  
  34. def configcell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=0):
  35.     """
  36.      config cell text and colors
  37.      and return text elements to add and to dell
  38.      @TODO: use fmt
  39.    """
  40.     text_add = []; text_del = [];
  41.     cell_val = array_df[lin][col]
  42.     tot_all = array_df[-1][-1]
  43.     per = (float(cell_val) / tot_all) * 100
  44.     curr_column = array_df[:,col]
  45.     ccl = len(curr_column)
  46.  
  47.     #last line  and/or last column
  48.     if(col == (ccl - 1)) or (lin == (ccl - 1)):
  49.         #tots and percents
  50.         if(cell_val != 0):
  51.             if(col == ccl - 1) and (lin == ccl - 1):
  52.                 tot_rig = 0
  53.                 for i in range(array_df.shape[0] - 1):
  54.                     tot_rig += array_df[i][i]
  55.                 per_ok = (float(tot_rig) / cell_val) * 100
  56.             elif(col == ccl - 1):
  57.                 tot_rig = array_df[lin][lin]
  58.                 per_ok = (float(tot_rig) / cell_val) * 100
  59.             elif(lin == ccl - 1):
  60.                 tot_rig = array_df[col][col]
  61.                 per_ok = (float(tot_rig) / cell_val) * 100
  62.             per_err = 100 - per_ok
  63.         else:
  64.             per_ok = per_err = 0
  65.  
  66.         per_ok_s = ['%.2f%%'%(per_ok), '100%'] [per_ok == 100]
  67.  
  68.         #text to DEL
  69.         text_del.append(oText)
  70.  
  71.         #text to ADD
  72.         font_prop = fm.FontProperties(weight='bold', size=fz)
  73.         text_kwargs = dict(color='w', ha="center", va="center", gid='sum', fontproperties=font_prop)
  74.         lis_txt = ['%d'%(cell_val), per_ok_s, '%.2f%%'%(per_err)]
  75.         lis_kwa = [text_kwargs]
  76.         dic = text_kwargs.copy(); dic['color'] = 'g'; lis_kwa.append(dic);
  77.         dic = text_kwargs.copy(); dic['color'] = 'r'; lis_kwa.append(dic);
  78.         lis_pos = [(oText._x, oText._y-0.3), (oText._x, oText._y), (oText._x, oText._y+0.3)]
  79.         for i in range(len(lis_txt)):
  80.             newText = dict(x=lis_pos[i][0], y=lis_pos[i][1], text=lis_txt[i], kw=lis_kwa[i])
  81.             #print 'lin: %s, col: %s, newText: %s' %(lin, col, newText)
  82.             text_add.append(newText)
  83.         #print '\n'
  84.  
  85.         #set background color for sum cells (last line and last column)
  86.         carr = [0.27, 0.30, 0.27, 1.0]
  87.         if(col == ccl - 1) and (lin == ccl - 1):
  88.             carr = [0.17, 0.20, 0.17, 1.0]
  89.         facecolors[posi] = carr
  90.  
  91.     else:
  92.         if(per > 0):
  93.             txt = '%s\n%.2f%%' %(cell_val, per)
  94.         else:
  95.             if(show_null_values == 0):
  96.                 txt = ''
  97.             elif(show_null_values == 1):
  98.                 txt = '0'
  99.             else:
  100.                 txt = '0\n0.0%'
  101.         oText.set_text(txt)
  102.  
  103.         #main diagonal
  104.         if(col == lin):
  105.             #set color of the textin the diagonal to white
  106.             oText.set_color('w')
  107.             # set background color in the diagonal to blue
  108.             facecolors[posi] = [0.35, 0.8, 0.55, 1.0]
  109.         else:
  110.             oText.set_color('r')
  111.  
  112.     return text_add, text_del
  113. #
  114.  
  115. def insert_totals(df_cm):
  116.     """ insert total column and line (the last ones) """
  117.     sum_col = []
  118.     for c in df_cm.columns:
  119.         sum_col.append( df_cm[c].sum() )
  120.     sum_lin = []
  121.     for item_line in df_cm.iterrows():
  122.         sum_lin.append( item_line[1].sum() )
  123.     df_cm['sum_lin'] = sum_lin
  124.     sum_col.append(np.sum(sum_lin))
  125.     df_cm.loc['sum_col'] = sum_col
  126.     #print ('\ndf_cm:\n', df_cm, '\n\b\n')
  127. #
  128.  
  129.  
  130.  
  131. def pretty_plot_confusion_matrix(df_cm, annot=True, cmap="Oranges", fmt='.2f', fz=11,
  132.       lw=0.5, cbar=False, figsize=[8,8], show_null_values=0):
  133.     """
  134.      print conf matrix with default layout (like matlab)
  135.      params:
  136.        df_cm   dataframe (pandas) without totals
  137.        annot   print text in each cell
  138.        cmap    Oranges,Oranges_r,YlGnBu,Blues,RdBu, ... see:
  139.        fz      fontsize
  140.        lw      linewidth
  141.    """
  142.     # create "Total" column
  143.     insert_totals(df_cm)
  144.  
  145.     #this is for print allways in the same window
  146.     fig, ax1 = get_new_fig('Conf matrix default', figsize)
  147.  
  148.     #thanks for seaborn
  149.     ax = sn.heatmap(df_cm, annot=annot, annot_kws={"size": fz}, linewidths=lw, ax=ax1,
  150.                     cbar=cbar, cmap=cmap, linecolor='w', fmt=fmt)
  151.  
  152.     #set ticklabels rotation
  153.     ax.set_xticklabels(ax.get_xticklabels(), rotation = 45, fontsize = 10)
  154.     ax.set_yticklabels(ax.get_yticklabels(), rotation = 25, fontsize = 10)
  155.  
  156.     # Turn off all the ticks
  157.     for t in ax.xaxis.get_major_ticks():
  158.         t.tick1On = False
  159.         t.tick2On = False
  160.     for t in ax.yaxis.get_major_ticks():
  161.         t.tick1On = False
  162.         t.tick2On = False
  163.  
  164.     #face colors list
  165.     quadmesh = ax.findobj(QuadMesh)[0]
  166.     facecolors = quadmesh.get_facecolors()
  167.  
  168.     #iter in text elements
  169.     array_df = np.array( df_cm.to_records(index=False).tolist() )
  170.     text_add = []; text_del = [];
  171.     posi = -1 #from left to right, bottom to top.
  172.     for t in ax.collections[0].axes.texts: #ax.texts:
  173.         pos = np.array( t.get_position()) - [0.5,0.5]
  174.         lin = int(pos[1]); col = int(pos[0]);
  175.         posi += 1
  176.         #print ('>>> pos: %s, posi: %s, val: %s, txt: %s' %(pos, posi, array_df[lin][col], t.get_text()))
  177.  
  178.         #set text
  179.         txt_res = configcell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values)
  180.  
  181.         text_add.extend(txt_res[0])
  182.         text_del.extend(txt_res[1])
  183.  
  184.     #remove the old ones
  185.     for item in text_del:
  186.         item.remove()
  187.     #append the new ones
  188.     for item in text_add:
  189.         ax.text(item['x'], item['y'], item['text'], **item['kw'])
  190.  
  191.     #titles and legends
  192.     ax.set_title('Confusion matrix')
  193.     ax.set_xlabel('Predicted')
  194.     ax.set_ylabel('Actual')
  195.     plt.tight_layout()  #set layout slim
  196.     plt.show()
  197. #
  198.  
  199. def plot_confusion_matrix_from_data(y_test, predictions, columns=None, annot=True,
  200.       cmap="Oranges", fmt='.2f', fz=11, lw=0.5, cbar=False, figsize=[8,8], show_null_values=0):
  201.     """
  202.        plot confusion matrix function with y_test (actual values) and predictions (predic),
  203.        whitout a confusion matrix yet
  204.    """
  205.     from sklearn.metrics import confusion_matrix
  206.     from pandas import DataFrame
  207.  
  208.     #data
  209.     if(not columns):
  210.         columns = range(1, len(np.unique(y_test))+1)
  211.  
  212.     confm = confusion_matrix(y_test, predictions)
  213.     cmap = 'Oranges';
  214.     fz = 11;
  215.     figsize=[9,9];
  216.     show_null_values = 2
  217.     df_cm = DataFrame(confm, index=columns, columns=columns)
  218.     pretty_plot_confusion_matrix(df_cm, fz=fz, cmap=cmap, figsize=figsize, show_null_values=show_null_values)
  219. #
  220.  
  221.  
  222.  
  223. #
  224. #Test functions
  225. #
  226. def _test_cm():
  227.     #test function with confusion matrix done
  228.     array = np.array( [[13,  0,  1,  0,  2,  0],
  229.                        [ 0, 50,  2,  0, 10,  0],
  230.                        [ 0, 13, 16,  0,  0,  3],
  231.                        [ 0,  0,  0, 13,  1,  0],
  232.                        [ 0, 40,  0,  1, 15,  0],
  233.                        [ 0,  0,  0,  0,  0, 20]])
  234.     #get pandas dataframe
  235.     df_cm = DataFrame(array, index=range(1,7), columns=range(1,7))
  236.     #colormap: see this and choose your more dear
  237.     cmap = 'PuRd'
  238.     pretty_plot_confusion_matrix(df_cm, cmap=cmap)
  239. #
  240.  
  241. def _test_data_class():
  242.     """ test function with y_test (actual values) and predictions (predic) """
  243.     #data
  244.     y_test = np.array([1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5])
  245.     predic = np.array([1,2,4,3,5, 1,2,4,3,5, 1,2,3,4,4, 1,4,3,4,5, 1,2,4,4,5, 1,2,4,4,5, 1,2,4,4,5, 1,2,4,4,5, 1,2,3,3,5, 1,2,3,3,5, 1,2,3,4,4, 1,2,3,4,1, 1,2,3,4,1, 1,2,3,4,1, 1,2,4,4,5, 1,2,4,4,5, 1,2,4,4,5, 1,2,4,4,5, 1,2,3,4,5])
  246.     columns = []
  247.     annot = True;
  248.     cmap = 'Oranges';
  249.     fmt = '.2f'
  250.     lw = 0.5
  251.     cbar = False
  252.     show_null_values = 2
  253.     #size::
  254.     fz = 12;
  255.     figsize = [9,9];
  256.     if(len(y_test) > 10):
  257.         fz=9; figsize=[14,14];
  258.     plot_confusion_matrix_from_data(y_test, predic, columns,
  259.       annot, cmap, fmt, fz, lw, cbar, figsize, show_null_values)
  260. #
  261.  
  262.  
  263. #
  264. #MAIN
  265. #
  266. if(__name__ == '__main__'):
  267.     print('__main__')
  268.     print('_test_cm: test function with confusion matrix done\nand pause')
  269.     _test_cm()
  270.     plt.pause(5)
  271.     print('_test_data_class: test function with y_test (actual values) and predictions (predic)')
  272.     _test_data_class()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement