Advertisement
Guest User

Untitled

a guest
Feb 23rd, 2019
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.01 KB | None | 0 0
  1. import func.nn.backprop.BackPropagationNetwork;
  2. import func.nn.backprop.BackPropagationNetworkFactory;
  3. import opt.OptimizationAlgorithm;
  4. import opt.example.NeuralNetworkOptimizationProblem;
  5. import opt.ga.StandardGeneticAlgorithm;
  6. import shared.DataSet;
  7. import shared.Instance;
  8. import shared.SumOfSquaresError;
  9. import shared.filt.TestTrainSplitFilter;
  10. import shared.reader.ArffDataSetReader;
  11.  
  12. public class RandomizedOptimization
  13. {
  14. private static final double ITERATIONS = 1000;
  15.  
  16. public static void main(String[] args) {
  17. // Load instances
  18. Instance[] data = getData("/Documents/Georgia Tech/Spring 2019/cs4641/Datasets/phishing-websites.arff");
  19. if(data == null) {
  20. System.out.println("Error loading data.");
  21. return;
  22. }
  23.  
  24. DataSet dataSet = new DataSet(data);
  25. SumOfSquaresError errorFunction = new SumOfSquaresError();
  26. TestTrainSplitFilter filter = new TestTrainSplitFilter(70);
  27. filter.filter(dataSet);
  28.  
  29. BackPropagationNetworkFactory nnFactory = new BackPropagationNetworkFactory();
  30. BackPropagationNetwork neuralNetwork = nnFactory.createClassificationNetwork(new int[] {data[0].getData().size(), 1, 1});
  31. NeuralNetworkOptimizationProblem weightProblem = new NeuralNetworkOptimizationProblem(filter.getTrainingSet(), neuralNetwork, errorFunction);
  32.  
  33. //RandomizedHillClimbing rhc = new RandomizedHillClimbing(weightProblem);
  34. //SimulatedAnnealing sa = new SimulatedAnnealing(1.0E5, 0.25, weightProblem);
  35. StandardGeneticAlgorithm ga = new StandardGeneticAlgorithm(300, 72, 54, weightProblem);
  36. startRecording("Curve");
  37. run(dataSet, neuralNetwork, ga, filter);
  38. stopRecording();
  39.  
  40. //gridSearch(dataSet, neuralNetwork, weightProblem, filter);
  41. }
  42.  
  43. private static void gridSearch(DataSet dataSet, BackPropagationNetwork neuralNetwork, NeuralNetworkOptimizationProblem problem, TestTrainSplitFilter filter) {
  44. Instance[] testSet = filter.getTestingSet().getInstances();
  45.  
  46. int[] populations = {100, 200, 300, 400, 500};
  47.  
  48. for(int population : populations) {
  49. StandardGeneticAlgorithm algorithm = new StandardGeneticAlgorithm(population, (int)(0.24*population), (int)(0.18*population), problem);
  50.  
  51. for(int i = 0; i < ITERATIONS; i++) {
  52. int correctTest = 0;
  53.  
  54. algorithm.train();
  55.  
  56. for(int j = 0; j < testSet.length; j++) {
  57. neuralNetwork.setInputValues(testSet[j].getData());
  58. neuralNetwork.run(); // feed forward
  59. double expected = Double.parseDouble(testSet[j].getLabel().toString());
  60. double out = Double.parseDouble(neuralNetwork.getOutputValues().toString());
  61. if(Math.abs(expected-out) < 0.5D) {
  62. correctTest++;
  63. }
  64. }
  65.  
  66. if(i == ITERATIONS-1) {
  67. System.out.println("Pop: " + population + " = " + ((double)correctTest/testSet.length));
  68. }
  69. }
  70. }
  71. }
  72.  
  73. private static void run(DataSet dataSet, BackPropagationNetwork neuralNetwork, OptimizationAlgorithm algorithm, TestTrainSplitFilter filter) {
  74. Instance[] trainSet = filter.getTrainingSet().getInstances();
  75. Instance[] testSet = filter.getTestingSet().getInstances();
  76.  
  77. for(int i = 0; i < ITERATIONS; i++) {
  78. int correctTrain = 0, correctTest = 0;
  79.  
  80. algorithm.train();
  81.  
  82. for(int j = 0; j < trainSet.length; j++) {
  83. neuralNetwork.setInputValues(trainSet[j].getData());
  84. neuralNetwork.run(); // feed forward
  85. double expected = Double.parseDouble(trainSet[j].getLabel().toString());
  86. double out = Double.parseDouble(neuralNetwork.getOutputValues().toString());
  87. if(Math.abs(expected-out) < 0.5D) {
  88. correctTrain++;
  89. }
  90. }
  91.  
  92. for(int j = 0; j < testSet.length; j++) {
  93. neuralNetwork.setInputValues(testSet[j].getData());
  94. neuralNetwork.run(); // feed forward
  95. double expected = Double.parseDouble(testSet[j].getLabel().toString());
  96. double out = Double.parseDouble(neuralNetwork.getOutputValues().toString());
  97. if(Math.abs(expected-out) < 0.5D) {
  98. correctTest++;
  99. }
  100. }
  101.  
  102. System.out.println(i + "," + ((double)correctTrain/trainSet.length) + "," + ((double)correctTest/testSet.length));
  103. }
  104. }
  105.  
  106. private static Instance[] getData(String s) {
  107. try {
  108. ArffDataSetReader reader = new ArffDataSetReader(getHomeDirectory() + s);
  109. Instance[] instances = reader.read().getInstances();
  110. for(int i = 0; i < instances.length; i++) {
  111. double[] in = new double[instances[i].getData().size()-1];
  112. double out = instances[i].getData().get(instances[i].getData().size()-1);
  113. for(int j = 0; j < instances[i].getData().size()-1; j++) {
  114. in[j] = instances[i].getData().get(j);
  115. }
  116. instances[i] = new Instance(in);
  117. instances[i].setLabel(new Instance(out == 1 ? 1 : 0));
  118. }
  119. return instances;
  120. } catch(Exception e) {
  121. e.printStackTrace();
  122. return null;
  123. }
  124. }
  125.  
  126. private static String getHomeDirectory() {
  127. return System.getProperty("user.home");
  128. }
  129.  
  130. private static long timestamp = 0;
  131.  
  132. public static void startRecording(String s) {
  133. timestamp = System.currentTimeMillis();
  134. System.out.println("Recording time for: " + s);
  135. }
  136.  
  137. public static void stopRecording() {
  138. long diff = System.currentTimeMillis()-timestamp;
  139. System.out.println("Time elapsed: " + diff);
  140. }
  141. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement