Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from keras.callbacks import Callback
- import time
- import matplotlib.pyplot as plt # conda install matplotlib
- import matplotlib.lines as mlines
class ReportCallback(Callback):
    """Keras callback that logs per-epoch metrics to CSV and plots them.

    Every epoch appends one row ``epoch,time,loss,acc,val_loss,val_acc`` to
    ``<filename>.csv`` (``time`` is seconds since training began). When
    training ends -- normally or early -- the CSV is closed and a twin-axis
    loss/accuracy chart is saved to ``<filename>_graph.png``.

    Args:
        filename: Path prefix for the ``.csv`` report and ``_graph.png`` plot.
        xTrain: Training inputs; stored on the instance, not used by this class.
        yTrain: Training targets; stored on the instance, not used by this class.
    """

    # Colours shared by the axis labels and the curves they describe.
    _LOSS_COLOR = "#CC2529"
    _ACC_COLOR = "#396BB1"

    # Fallback log keys tried when the short metric name is missing.
    _METRIC_ALIASES = {"acc": "accuracy", "val_acc": "val_accuracy"}

    def __init__(self, filename, xTrain, yTrain):
        # Initialise base Callback state before adding our own attributes.
        super(ReportCallback, self).__init__()
        self.filename = filename
        self.report_file = open(filename + ".csv", "w+")
        self.report_file.write("#epoch,time,loss,acc,val_loss,val_acc\n")
        self.allLoss = []
        self.allAcc = []
        self.allValLoss = []
        self.allValAcc = []
        self.xTrain = xTrain
        self.yTrain = yTrain

    def on_train_begin(self, logs=None):
        """Record the wall-clock start time used for the CSV ``time`` column."""
        self.beginTime = time.time()

    @classmethod
    def _metric(cls, logs, name):
        """Return metric *name* from *logs*, trying its long alias if absent."""
        value = logs.get(name)
        if value is None and name in cls._METRIC_ALIASES:
            value = logs.get(cls._METRIC_ALIASES[name])
        return value

    def on_epoch_end(self, epoch, logs=None):
        """Store this epoch's metrics and append them as one CSV row.

        Args:
            epoch: Zero-based epoch index; written to the CSV as ``epoch + 1``.
            logs: Metric dict for the epoch; missing metrics are written as ``None``.
        """
        logs = logs or {}
        loss = logs.get("loss")
        acc = self._metric(logs, "acc")
        val_loss = logs.get("val_loss")
        val_acc = self._metric(logs, "val_acc")

        self.allLoss.append(loss)
        self.allAcc.append(acc)
        self.allValLoss.append(val_loss)
        self.allValAcc.append(val_acc)

        row = [epoch + 1, time.time() - self.beginTime, loss, acc, val_loss, val_acc]
        self.report_file.write(",".join(str(v) for v in row) + "\n")
        # Flush so the CSV can be inspected while training is still running.
        self.report_file.flush()

    def on_train_end(self, logs=None):
        """Close the CSV report and save the metric plot.

        Runs on every training exit path (including early stopping), unlike a
        check against the configured epoch count.
        """
        if not self.report_file.closed:
            self.report_file.close()
        if self.allLoss:
            self._save_graph()

    def _save_graph(self):
        """Plot loss (left axis) and accuracy (right axis) per epoch to PNG."""
        x = range(1, len(self.allLoss) + 1)
        fig, ax1 = plt.subplots()
        try:
            ax1.set_ylabel("Loss", color=self._LOSS_COLOR)
            ax1.plot(x, self.allLoss, "--", color=self._LOSS_COLOR,
                     linewidth=2.0, label="Train Loss")
            ax1.plot(x, self.allValLoss, color=self._LOSS_COLOR,
                     linewidth=2.0, label="Test Loss")
            ax1.axis('tight')

            ax2 = ax1.twinx()
            ax2.set_ylabel("Accuracy", color=self._ACC_COLOR)
            ax2.plot(x, self.allAcc, "--", color=self._ACC_COLOR,
                     linewidth=2.0, label="Train Acc")
            ax2.plot(x, self.allValAcc, color=self._ACC_COLOR,
                     linewidth=2.0, label="Test Acc")
            ax2.axis('tight')
            ax1.set_xlabel("Epochs")

            # The legend explains only line style (dashed = training,
            # solid = test); colour is identified by the coloured axis labels.
            training_line = mlines.Line2D([], [], linestyle='--', linewidth=2.0,
                                          color="black", label="Training")
            test_line = mlines.Line2D([], [], linewidth=2.0,
                                      color="black", label="Test")
            ax2.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, ncol=2,
                       mode="expand", borderaxespad=0.,
                       handles=[training_line, test_line])
            fig.savefig(self.filename + "_graph.png")
        finally:
            # Release the figure so repeated training runs don't accumulate them.
            plt.close(fig)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement