Advertisement
Guest User

Caffe train/val training/testing plot

a guest
Mar 13th, 2017
111
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.31 KB | None | 0 0
  1. # In the name of GOD the most compassionate the most merciful
  2. # Originally developed by Yasse Souri
# Added a search of the current directory so that users don't have to use the command prompt anymore.
# Also prints the top 4 accuracies achieved so far and shows the highest one in the plot title.
  5. # Coded By: Seyyed Hossein Hasan Pour (Coderx7@gmail.com)
  6. # -------How to Use ---------------
# 1. Place your Caffe training/test log file (with a .log extension) next to this script
#    and run the script. If you have multiple logs next to the script, all of them are plotted.
#    You may also copy this script into the working directory where you generate/keep your train/test logs
#    and simply run it there to see the curves plotted.
#    The script is standalone.
# 2. You can also pass log files as command-line arguments, separated by spaces,
#    and you are good to go.
  14. #----------------------------------
  15. import numpy as np
  16. import re
  17. import click
  18. import glob, os
  19. from matplotlib import pylab as plt
  20. import operator
  21. import ntpath
  22. @click.command()
  23. @click.argument('files', nargs=-1, type=click.Path(exists=True))
  24. def main(files):
  25.     plt.style.use('ggplot')
  26.     fig, ax1 = plt.subplots()
  27.     ax2 = ax1.twinx()
  28.     ax1.set_xlabel('iteration')
  29.     ax1.set_ylabel('loss')
  30.     ax2.set_ylabel('accuracy %')
  31.     if not files:
  32.         print 'no args found'
  33.         print '\n\rloading all files with .log extension from current directory'
  34.         os.chdir(".")
  35.         files = glob.glob("*.log")
  36.  
  37.     for i, log_file in enumerate(files):
  38.         loss_iterations, losses, accuracy_iterations, accuracies, accuracies_iteration_checkpoints_ind, fileName = parse_log(log_file)
  39.         disp_results(fig, ax1, ax2, loss_iterations, losses, accuracy_iterations, accuracies, accuracies_iteration_checkpoints_ind, fileName, color_ind=i)
  40.        
  41.         loss_iterations, losses, accuracy_iterations, accuracies, accuracies_iteration_checkpoints_ind, fileName = parse_log2(log_file)
  42.         disp_results(fig, ax1, ax2, loss_iterations, losses, accuracy_iterations, accuracies, accuracies_iteration_checkpoints_ind, fileName, color_ind=i+1)
  43.        
  44.     plt.show()
  45.    
  46.  
  47. def parse_log2(log_file):
  48.     with open(log_file, 'r') as log_file2:
  49.         log = log_file2.read()
  50.  
  51.     #loss_pattern = r"Iteration (?P<iter_num>\d+), loss = (?P<loss_val>[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)"
  52.     #loss_pattern2 = r"Batch (?P<iter_num>\d+), accuracy = (?P<accuracy>[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)(\n.* = (?P<loss>[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?))";
  53.     losses = []
  54.     loss_iterations = []
  55.     loss_accuracy_pattern = r"Iteration (?P<iter_num>\d+) (.?)*, loss = ((\d*.\d*)+)\n.*( *)Train net output #0: accuracy_training = ((\d*.\d*)+)"
  56.     accuracies = []
  57.     accuracy_iterations = []
  58.     accuracies_iteration_checkpoints_ind = []
  59.      
  60.     fileName= 'train_'+os.path.basename(log_file)
  61.     #print re.search(loss_pattern2,log)    
  62.     # if re.search(loss_pattern2,log) != None:
  63.         # for r in re.findall(loss_pattern2, log):
  64.             # loss_iterations.append(int(r[0]))
  65.             ##print '\n'
  66.             ##print (r)
  67.             # losses.append(float(r[6]))
  68.             # iteration = int(r[0])
  69.             # accuracy = float(r[1]) * 100
  70.  
  71.             # if iteration % 10000 == 0 and iteration > 0:
  72.                 # accuracies_iteration_checkpoints_ind.append(len(accuracy_iterations))
  73.  
  74.             # accuracy_iterations.append(iteration)
  75.             # accuracies.append(accuracy)
  76.        
  77.    
  78.     if re.search(loss_accuracy_pattern,log) != None:
  79.         for r in re.findall(loss_accuracy_pattern, log):
  80.             #print '\n'
  81.             #print (r) 
  82.             iteration = int(r[0])
  83.             loss_iterations.append(iteration)
  84.             losses.append(float(r[3]))
  85.            
  86.             accuracy = float(r[6]) * 100
  87.            
  88.             if iteration % 10000 == 0 and iteration > 0:
  89.                 accuracies_iteration_checkpoints_ind.append(len(accuracy_iterations))
  90.  
  91.             accuracy_iterations.append(iteration)
  92.             accuracies.append(accuracy)
  93.            
  94.         # for r in re.findall(accuracy_pattern, log):
  95.             # iteration = int(r[0])
  96.             # accuracy = float(r[1]) * 100
  97.  
  98.             # if iteration % 10000 == 0 and iteration > 0:
  99.                 # accuracies_iteration_checkpoints_ind.append(len(accuracy_iterations))
  100.  
  101.             # accuracy_iterations.append(iteration)
  102.             # accuracies.append(accuracy)
  103.        
  104.     loss_iterations = np.array(loss_iterations)
  105.     losses = np.array(losses)
  106.  
  107.     accuracy_iterations = np.array(accuracy_iterations)
  108.     accuracies = np.array(accuracies)
  109.    
  110.     return loss_iterations, losses, accuracy_iterations, accuracies, accuracies_iteration_checkpoints_ind, fileName
  111.  
  112. def parse_log(log_file):
  113.     with open(log_file, 'r') as log_file2:
  114.         log = log_file2.read()
  115.  
  116.     #loss_pattern = r"Iteration (?P<iter_num>\d+), loss = (?P<loss_val>[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)"
  117.     #loss_pattern2 = r"Batch (?P<iter_num>\d+), accuracy = (?P<accuracy>[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)(\n.* = (?P<loss>[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?))";
  118.     losses = []
  119.     loss_iterations = []
  120.     loss_accuracy_pattern = r"Iteration (?P<iter_num>\d+), Testing net \(#0\)((.*?)(\n)*)* accuracy = ((\d*.\d*)+)((.*?)(\n)*)*(.?)*loss = ((\d*.\d*)+) \("
  121.     accuracies = []
  122.     accuracy_iterations = []
  123.     accuracies_iteration_checkpoints_ind = []
  124.      
  125.     fileName= 'test_'+os.path.basename(log_file)
  126.     # print re.search(loss_pattern2,log)    
  127.     # if re.search(loss_pattern2,log) != None:
  128.         # for r in re.findall(loss_pattern2, log):
  129.             # loss_iterations.append(int(r[0]))
  130.             # #print '\n'
  131.             # #print (r)   
  132.             # losses.append(float(r[6]))
  133.             # iteration = int(r[0])
  134.             # accuracy = float(r[1]) * 100
  135.  
  136.             # if iteration % 10000 == 0 and iteration > 0:
  137.                 # accuracies_iteration_checkpoints_ind.append(len(accuracy_iterations))
  138.  
  139.             # accuracy_iterations.append(iteration)
  140.             # accuracies.append(accuracy)
  141.        
  142.    
  143.     if re.search(loss_accuracy_pattern,log) != None:
  144.         for r in re.findall(loss_accuracy_pattern, log):
  145.             iteration = int(r[0])
  146.             loss_iterations.append(iteration)
  147.             losses.append(float(r[10]))
  148.             #print '\n'
  149.             #print (r) 
  150.            
  151.             accuracy = float(r[4]) * 100
  152.  
  153.             if iteration % 10000 == 0 and iteration > 0:
  154.                 accuracies_iteration_checkpoints_ind.append(len(accuracy_iterations))
  155.  
  156.             accuracy_iterations.append(iteration)
  157.             accuracies.append(accuracy)
  158.            
  159.         # for r in re.findall(accuracy_pattern, log):
  160.             # #print '\n'
  161.             # #print (r)   
  162.             # iteration = int(r[0])
  163.             # accuracy = float(r[5]) * 100
  164.  
  165.             # if iteration % 10000 == 0 and iteration > 0:
  166.                 # accuracies_iteration_checkpoints_ind.append(len(accuracy_iterations))
  167.  
  168.             # accuracy_iterations.append(iteration)
  169.             # accuracies.append(accuracy)
  170.        
  171.     loss_iterations = np.array(loss_iterations)
  172.     losses = np.array(losses)
  173.  
  174.     accuracy_iterations = np.array(accuracy_iterations)
  175.     accuracies = np.array(accuracies)
  176.    
  177.     return loss_iterations, losses, accuracy_iterations, accuracies, accuracies_iteration_checkpoints_ind, fileName
  178.  
  179.  
  180. def disp_results(fig, ax1, ax2, loss_iterations, losses, accuracy_iterations, accuracies, accuracies_iteration_checkpoints_ind, fileName, color_ind=0):
  181.     modula = len(plt.rcParams['axes.color_cycle'])
  182.     acrIterations =[]
  183.     top_acrs={}
  184.     if accuracies.size:
  185.         if  accuracies.size>4:
  186.             top_n = 4
  187.         else:
  188.             top_n = accuracies.size -1     
  189.         temp = np.argpartition(-accuracies, top_n)
  190.         result_indexces = temp[:top_n]
  191.         temp = np.partition(-accuracies, top_n)
  192.         result = -temp[:top_n]
  193.         for acr in result_indexces:
  194.             acrIterations.append(accuracy_iterations[acr])
  195.             top_acrs[str(accuracy_iterations[acr])]=str(accuracies[acr])
  196.  
  197.         sorted_top4 = sorted(top_acrs.items(), key=operator.itemgetter(1))
  198.         maxAcc = np.amax(accuracies, axis=0)
  199.         iterIndx = np.argmax(accuracies)
  200.         maxAccIter = accuracy_iterations[iterIndx]
  201.         maxIter =   accuracy_iterations[-1]
  202.         consoleInfo = format('\n[%s]:maximum accuracy [from 0 to %s ] = [Iteration %s]: %s ' %(fileName,maxIter,maxAccIter ,maxAcc))
  203.         plotTitle = format('max accuracy(%s) [Iteration %s]: %s ' % (fileName,maxAccIter, maxAcc))
  204.         print (consoleInfo)
  205.         #print (str(result))
  206.         #print(acrIterations)
  207.        # print 'Top 4 accuracies:'     
  208.         print ('Top 4 accuracies:'+str(sorted_top4))       
  209.         plt.title(plotTitle)
  210.     ax1.plot(loss_iterations, losses, color=plt.rcParams['axes.color_cycle'][(color_ind * 2 + 0) % modula])
  211.     ax2.plot(accuracy_iterations, accuracies, plt.rcParams['axes.color_cycle'][(color_ind * 2 + 1) % modula], label=str(fileName))
  212.     ax2.plot(accuracy_iterations[accuracies_iteration_checkpoints_ind], accuracies[accuracies_iteration_checkpoints_ind], 'o', color=plt.rcParams['axes.color_cycle'][(color_ind * 2 + 1) % modula])
  213.     plt.legend(loc='lower right')
  214.  
  215.  
  216. if __name__ == '__main__':
  217.     main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement