Advertisement
Guest User

Untitled

a guest
Oct 23rd, 2017
63
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.99 KB | None | 0 0
  1. from ai.FileIO import FileIO
  2. from ai.NeuralNet import NeutalNet
  3. from ai.visualization import drawScatterPlot2, drawFunctionPlot
  4. import numpy;
  5.  
  6. def doClassification(trainFile, testFile, testHasAns=True):
  7.  
  8. trainData = FileIO.readFile2(trainFile)
  9. testData = FileIO.readFile2(testFile, train=False, testHasAns=testHasAns)
  10.  
  11. res_list = [x[-1] for x in trainData]
  12. unique = numpy.unique(res_list)
  13. if testHasAns:
  14. evalData = testData
  15. else:
  16. evalData = None
  17.  
  18. print('Learning...')
  19. nn = NeutalNet([2, 10,10, len(unique)], 'classification', 'sigmoid')
  20. nn.SGD(trainData, 30, 10, 0.5,evalData)
  21. print('Testing...')
  22. nnGuesses = []
  23. for x_t in testData:
  24. y_t = nn.feedforward(x_t)
  25. nnGuesses.append((x_t, y_t))
  26.  
  27. if not testHasAns: # test does not have ans so we guess output and save to file
  28. FileIO.saveFile(testData, nn.problemType, testFile)
  29.  
  30. drawScatterPlot2(trainData, nnGuesses, nn.errors)
  31.  
  32.  
  33. def doRegression(trainFile, testFile, testHasAns):
  34.  
  35. trainData = FileIO.readFile2(trainFile)
  36. res_list = [x[-1] for x in trainData]
  37. maxValue = float(max(res_list))
  38. minValue = float(min(res_list))
  39. print('Learning...')
  40. nn = NeutalNet([1, 6, 1], 'regression', 'tanh',maxValue, minValue)
  41. nn.SGD(trainData, 50, 10, 0.1)
  42.  
  43. print('Testing...')
  44. listInput = FileIO.readFile2(testFile, train=False, testHasAns=testHasAns)
  45. testData = []
  46. for x_t in listInput:
  47. y_t = nn.feedforward(x_t)
  48. testData.append((x_t, y_t))
  49. drawFunctionPlot(trainData, testData, nn.errors)
  50. FileIO.saveFile(testData, nn.problemType)
  51.  
  52. def main():
  53. trainFileCls = './ai/data.train.csv'
  54. testFileCls = './ai/data.test.csv'
  55. doClassification(trainFileCls, testFileCls, testHasAns=False)
  56. trainFileReg = './ai/data.xsq.train.csv'
  57. testFileReg = './ai/data.xsq.test.csv'
  58. #doRegression(trainFileReg, testFileReg, testHasAns=False)
  59.  
  60.  
  61. if __name__ == '__main__':
  62. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement