Advertisement
Guest User

reportCallback

a guest
Mar 4th, 2018
107
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.06 KB | None | 0 0
  1. from keras.callbacks import Callback
  2. import time
  3. import matplotlib.pyplot as plt # conda install matplotlib
  4. import matplotlib.lines as mlines
  5.  
  6. class ReportCallback(Callback):
  7.     def __init__(self, filename, xTrain, yTrain):
  8.         self.filename = filename   
  9.         self.report_file = open(filename+".csv", "w+")
  10.         self.report_file.write("#epoch,time,loss,acc,val_loss,val_acc\n");
  11.         self.allLoss = []
  12.         self.allAcc = []
  13.         self.allValLoss = []
  14.         self.allValAcc = []
  15.         self.xTrain = xTrain
  16.         self.yTrain = yTrain
  17.    
  18.     def on_train_begin(self, logs={}):
  19.         self.beginTime = time.time()
  20.  
  21.     def on_epoch_end(self, epoch, logs={}):    
  22.         self.allLoss.append(logs.get('loss'))
  23.         self.allAcc.append(logs.get('acc'))
  24.         self.allValLoss.append(logs.get('val_loss'))
  25.         self.allValAcc.append(logs.get('val_acc'))
  26.        
  27.         self.report_file.write(
  28.                         str(epoch+1)+","+
  29.                         str(time.time()-self.beginTime)+","+
  30.                         str(logs.get('loss'))+","+
  31.                         str(logs.get('acc'))+","+
  32.                         str(logs.get('val_loss'))+","+
  33.                         str(logs.get('val_acc'))+"\n"
  34.                         )
  35.  
  36.         if epoch+1 == self.params['nb_epoch']:
  37.             self.report_file.close()
  38.            
  39.             x = range(1, epoch+2)
  40.            
  41.             fig, ax1 = plt.subplots()
  42.             ax1.set_ylabel("Loss", color="#CC2529")
  43.             ax1.plot(x, self.allLoss, "--", color="#CC2529", linewidth=2.0, label="Train Loss")
  44.             ax1.plot(x, self.allValLoss, color="#CC2529", linewidth=2.0, label="Test Loss")
  45.             ax1.axis('tight')
  46.            
  47.             ax2 = ax1.twinx()
  48.             ax2.set_ylabel("Accuracy", color="#396BB1")
  49.             ax2.plot(x, self.allAcc, "--", color="#396BB1", linewidth=2.0, label="Train Acc")
  50.             ax2.plot(x, self.allValAcc, color="#396BB1", linewidth=2.0, label="Test Acc")
  51.             ax2.axis('tight')
  52.            
  53.             ax1.set_xlabel("Epochs")
  54.             training_line  = mlines.Line2D([], [], linestyle='--', linewidth=2.0, color="black", label="Training")
  55.             test_line  = mlines.Line2D([], [], linewidth=2.0, color="black", label="Test")
  56.             plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, ncol=2, mode="expand", borderaxespad=0., handles=[training_line, test_line])
  57.            
  58.             plt.savefig(self.filename+"_graph.png")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement