Advertisement
4ever_bored

theano_main.py

Oct 14th, 2016
161
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.17 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Fri Oct  7 15:44:38 2016
  4.  
  5. @author: anpa
  6. """
  7.  
  8. import os
  9. import sys
  10. import timeit
  11. import signal
  12.  
  13. import numpy
  14. import matplotlib.pyplot as plt
  15.  
  16. import theano
  17. import theano.tensor as T
  18.  
  19. from dense import DenseLayer
  20. import updates as update_rules
  21. import lenet5
  22. import utils
  23. from utils import eprint
  24.  
# --- Theano build configuration (Windows / OpenBLAS) ------------------------
# Link against OpenBLAS and point g++ at the local OpenBLAS install.
# NOTE(review): these paths are machine-specific — adjust per host.
theano.config.blas.ldflags='-lopenblas'
cxxflag1 = '-IC:\\OpenBLAS-v0.2.15-Win64-int32\\include '
cxxflag2 = '-LC:\\OpenBLAS-v0.2.15-Win64-int32\\lib'
theano.config.gcc.cxxflags= cxxflag1 + cxxflag2

# --- Dataset locations ------------------------------------------------------
# NOTE(review): all paths are empty strings; they must be filled in before
# the script can run.
# dataset root folder
data_dir = ''
# training set input file
training_im_file = ''
# training set labels file
training_lb_file = ''
# test set input file
test_im_file = ''
# test set labels file
test_lb_file = ''

# --- Training hyper-parameters ----------------------------------------------
# number of samples per batch
batch_size = 1273
# number of batches per epoch
nbatch = 637
# input image dimensions (rows, cols)
img_rows, img_cols = 20, 195
# number of convolutional filters per conv layer
nkernels = [30, 30]
# size of pooling area for max pooling
pooling_size = (2, 2)
# convolution kernel sizes, (rows, cols), for conv layers 0 and 1
kernel_size = [(3, 16), (2, 5)]
# number of epochs
nepochs = 2

# Set by the signal handler and polled inside the training loop so the
# process can exit cleanly when interrupted.
SIGTERM_FLAG = False
  57. def sigterm_handler(_signo, _stack_frame):
  58.     global SIGTERM_FLAG
  59.     print('SIGTERM Caught!')
  60.     SIGTERM_FLAG = True
  61.            
  62. def evaluate_lenet5():
  63.    
  64.     bg = utils.batch_generator(data_dir + training_im_file,
  65.                                data_dir + training_lb_file,
  66.                                (img_rows, img_cols), r_eff, nbatch, batch_size)
  67.     train_set_x, train_set_y = next(bg)
  68.    
  69.     valid_set_x, valid_set_y = utils.get_testset(data_dir + test_im_file,
  70.                                                  data_dir + test_lb_file,
  71.                                                  (img_rows, img_cols), r_eff)
  72.    
  73.     x = T.dtensor4('x')
  74.     y = T.dvector('y')  
  75.  
  76.     print('... building the model')
  77.  
  78.     layer0 = lenet5.LeNetConvPoolLayer(
  79.         input=x,
  80.         image_shape=(None, 1, img_rows, img_cols),
  81.         filter_shape=(nkernels[0], 1, kernel_size[0][0], kernel_size[0][1]),
  82.         poolsize=(2, 2)
  83.     )
  84.  
  85.     layer1 = lenet5.LeNetConvPoolLayer(
  86.         input=layer0.output,
  87.         image_shape=(None, nkernels[0], layer0.dim0, layer0.dim1),
  88.         filter_shape=(nkernels[1], nkernels[0],
  89.                       kernel_size[1][0], kernel_size[1][1]),
  90.         poolsize=(2, 2)
  91.     )
  92.  
  93.     layer2_input = layer1.output.flatten(2)
  94.  
  95.     # construct a fully-connected sigmoidal layer
  96.     layer2 = DenseLayer(
  97.         input=layer2_input,
  98.         n_in=nkernels[1] * layer1.dim0 * layer1.dim1,
  99.         n_out=20,
  100.         activation=T.tanh
  101.     )
  102.    
  103.     layer3 = DenseLayer(
  104.         input=layer2.output,
  105.         n_in=20,
  106.         n_out=15,
  107.         activation=T.tanh
  108.     )
  109.  
  110.     layer4 = DenseLayer(input=layer3.output, rng=rng, n_in=15, n_out=1,
  111.                         activation=None)
  112.  
  113.     cost = layer4.mse(y)
  114.  
  115.     params = layer4.params + layer3.params + layer2.params + layer1.params + layer0.params
  116.  
  117.     grads = T.grad(cost, params)
  118.    
  119.     sgd = update_rules.SGD(lr=learning_rate, momentum=0.9)
  120.     Adam = update_rules.Adam()
  121.     updates = Adam.get_updates(params, cost, grads)
  122.    
  123.     validate_model = theano.function(
  124.         [],
  125.         cost,
  126.         givens={
  127.             x: valid_set_x,
  128.             y: valid_set_y
  129.         }
  130.     )
  131.  
  132.     train_model = theano.function(
  133.         [],
  134.         cost,
  135.         updates=updates,
  136.         givens={
  137.             x: train_set_x,
  138.             y: train_set_y
  139.         }
  140.     )
  141.    
  142.     preds = theano.function(inputs=[],
  143.         outputs=layer4.y_pred,
  144.         givens={
  145.             x: valid_set_x,
  146.         }
  147.     )
  148.  
  149.     print('... training')
  150.  
  151.     validation_frequency = 100
  152.  
  153.     best_validation_loss = numpy.inf
  154.     start_time = timeit.default_timer()
  155.  
  156.     epoch = 0
  157.     train_cost = []
  158.     test_cost = []
  159.     signal.signal(signal.SIGINT, sigterm_handler)
  160.  
  161.     while (epoch < nepochs):
  162.         epoch = epoch + 1
  163.         for minibatch_index in range(nbatch):
  164.    
  165.             iter = (epoch - 1) * nbatch + minibatch_index
  166.                
  167.             if SIGTERM_FLAG is True:
  168.                 sys.exit(1)
  169.                    
  170.             cost_ij = train_model()
  171.             train_cost.append(cost_ij)
  172.             if iter % 100 == 0:
  173.                 print('training @ iter = ', iter)
  174.                 print('training error %f'%cost_ij)
  175.                 for i,p in enumerate(params):
  176.                     curr_p = p.get_value(borrow=True)
  177.                     if len(curr_p.shape) > 2:
  178.                         curr_p = curr_p.flatten()                        
  179.                     numpy.savetxt('parameters'+str(i), curr_p)                    
  180.    
  181.             if (iter + 1) % validation_frequency == 0:
  182.    
  183.                 validation_losses = validate_model()
  184.                 this_validation_loss = numpy.mean(validation_losses)
  185.                 test_cost.append(this_validation_loss)
  186.                 print('epoch %i, minibatch %i/%i, validation error %f' %
  187.                       (epoch, minibatch_index + 1, nbatch,
  188.                        this_validation_loss))
  189.  
  190.     end_time = timeit.default_timer()
  191.     print('Optimization complete.')
  192.     for i,p in enumerate(params):
  193.         curr_p = p.get_value(borrow=True)
  194.         if len(curr_p.shape) > 2:
  195.            curr_p = curr_p.flatten()                        
  196.         numpy.savetxt('./parameters'+str(i), curr_p)      
  197.     numpy.savetxt('./training_loss', train_cost)
  198.     numpy.savetxt('./test_loss', test_cost)
  199.  
  200.     print('Best validation score of %f obtained at iteration %i, '
  201.               'with test performance %f' %
  202.              (best_validation_loss, best_iter + 1, test_score))
  203.     print(('The code for file ' +
  204.             os.path.split(__file__)[1] +
  205.             ' ran for %.2fm' % ((end_time - start_time) / 3600.)),
  206.           file=sys.stderr)
  207.     pds = preds()
  208.     plt.figure('Vanilla')
  209.     plt.scatter(test_set_y, pds)
  210.     plt.plot([numpy.amin(test_set_y),
  211.              numpy.amax(test_set_y)],
  212.              [numpy.amin(test_set_y),
  213.              numpy.amax(test_set_y)])
  214.     plt.show()
  215.     return    
  216.    
  217.  
if __name__ == '__main__':
    # Script entry point: build, train and evaluate the model.
    evaluate_lenet5()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement