Advertisement
Miki19xs

Untitled

May 19th, 2017
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.41 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu May 18 16:27:51 2017
  4.  
  5. @author: Michele
  6. """
  7.  
  8. # -*- coding: utf-8 -*-
  9. """
  10. Created on Mon May  1 16:43:05 2017
  11. """
  12.  
  13.  
  14. import time
  15. import csv
  16. import numpy as np
  17.  
  18. from numpy import genfromtxt
  19. from matplotlib import pyplot as plt
  20.  
  21. from sklearn.neural_network import MLPClassifier
  22. from sklearn.metrics import accuracy_score
  23.  
  24.  
  25. def DrawDigit(A, label=''):
  26.     """ Draw single digit as a greyscale matrix"""
  27.     fig = plt.figure(figsize=(6,6))
  28.     # Uso la colormap 'gray' per avere la schacchiera in bianco&nero
  29.     img = plt.imshow(A, cmap='gray_r')
  30.     plt.xlabel(label)
  31.     plt.show()
  32.  
  33.    
  34. def ElaborateTrainingSet(data):
  35.     """ Elaborate training set """
  36.     X = []
  37.     Y = []    
  38.     for row in data:
  39.         X.append(np.array(row[1:]))
  40.         Y.append(int(row[0]))        
  41.     return X, Y
  42.  
  43.  
  44.  
  45. def ElaborateTestSet(data):
  46.     """ Elaborate test set """
  47.     X = []
  48.     for row in data:
  49.         X.append(np.array(row))
  50.  
  51.     return X
  52.  
  53.  
  54. def LearnANN(data):
  55.     """ Learn an Artificial Neural Network and return the corresponding object """
  56.     x_train, y_train = ElaborateTrainingSet(data)    
  57.    
  58.     # PRIMA DI FARE QUESTO ESERCIZIO, STUDIARE IL TUTORIAL:
  59.     # http://scikit-learn.org/stable/modules/neural_networks_supervised.html
  60.     #
  61.     # DA COMPLETARE: PROVARE I DIVERSI PARAMETRI DI QUESTA CLASSE
  62.     # http://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html
  63.     ann = MLPClassifier(hidden_layer_sizes=(100, ), activation='tanh', solver='adam', random_state=1)
  64.     # COME VOLETE PROGETTARE LA VOSTRA RETE MULTILIVELLO???
  65.    
  66.     ann.fit(x_train, y_train)
  67.    
  68.     # ESERCIZIO 2: INVECE DI USRARE LA LIBRERIA SCIKIT, IMPLEMENTARE UNA RETE
  69.     #              NEURALE BASANDOSI SULL'ESEMPIO VISTO AL SEMINARIO DEL 4 maggio 2017
  70.     return ann
  71.  
  72.  
  73. def TestANN(ann, x_test, y_test):
  74.     """ Test the learned ANN on the given set of data """
  75.     y_pred = ann.predict(x_test)
  76.            
  77.     print("Accuracy: ", accuracy_score(y_test, y_pred), ' - Number of itertions:', ann.n_iter_)
  78.    
  79.     # Write the predictinos in a .csv file
  80.     with open('solution.csv','w') as csv_file:
  81.         writer = csv.writer(csv_file, delimiter=',', lineterminator='\n')
  82.         writer.writerow(['ImageId','Label'])
  83.         for i,p in enumerate(y_pred):
  84.             writer.writerow([i+1,p])
  85.  
  86.  
  87. def EvaluateANN(ann, x_test):
  88.     """ Test the learned ANN and produce output for Kaggle """
  89.     start = time.time()
  90.    
  91.     y_pred = ann.predict(x_test)
  92.    
  93.     print('Evaluation time:', time.time()-start,'- size:', len(my_test))        
  94.     print('Number of itertions:', ann.n_iter_)
  95.    
  96.     # Write the predictinos in a .csv file
  97.     with open('solution.csv','w') as csv_file:
  98.         writer = csv.writer(csv_file, delimiter=',', lineterminator='\n')
  99.         writer.writerow(['ImageId','Label'])
  100.         for i,p in enumerate(y_pred):
  101.             writer.writerow([i+1,p])
  102.    
  103.  
  104. #------------------------------------------
  105. #              MAIN ENTRY POINT
  106. #------------------------------------------
  107. if __name__ == "__main__":
  108.     # Misura il tempo per le operazioni principali
  109.     start = time.time()
  110.    
  111.     # Fase 1: Training
  112.     # Read CSV from Numpy, Link:
  113.     # https://docs.scipy.org/doc/numpy/reference/generated/numpy.genfromtxt.html
  114.     my_data = genfromtxt('C:/Users/Michele/Desktop/Programmazione2-master/Assignments/hw3/minst_test_small.csv', delimiter=',', skip_header=1)            
  115.     print('Reading time:', time.time()-start)
  116.     start = time.time()
  117.  
  118.     # Cambia in True per plottare alcune immagine
  119.     if False:
  120.         for row in my_data[11:19]:
  121.             # Documentation for function 'reshape':
  122.             # https://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html
  123.             A = np.array(row[1:]).reshape(28,28)        
  124.             DrawDigit(A, 'Digit: ' + str(int(row[0])))
  125.  
  126.     ann = LearnANN(my_data)
  127.    
  128.     print('Learning time:', time.time()-start, '- size:', len(my_data))
  129.    
  130.     # Fase 2: local test for learning of parameters
  131.     # DA COMPLETARE TORVARE I VOSTRI PARAMETRI NEL MODO CHE PREFERITE
  132.    
  133.     # Fase 3: Evaluate on Kaggle test set
  134.     my_test = genfromtxt('C:/Users/Michele/Desktop/Programmazione2-master/Assignments/hw3/minst_train_small.csv', delimiter=',', skip_header=1)
  135.     x_test, y_test = ElaborateTrainingSet(my_test)
  136.     TestANN(ann, x_test,y_test)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement