Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import func.nn.backprop.BackPropagationNetwork;
import func.nn.backprop.BackPropagationNetworkFactory;
import opt.OptimizationAlgorithm;
import opt.example.NeuralNetworkOptimizationProblem;
import opt.ga.StandardGeneticAlgorithm;
import shared.DataSet;
import shared.Instance;
import shared.SumOfSquaresError;
import shared.filt.TestTrainSplitFilter;
import shared.reader.ArffDataSetReader;

import java.util.concurrent.TimeUnit;
- public class RandomizedOptimization
- {
- private static final double ITERATIONS = 1000;
- public static void main(String[] args) {
- // Load instances
- Instance[] data = getData("/Documents/Georgia Tech/Spring 2019/cs4641/Datasets/phishing-websites.arff");
- if(data == null) {
- System.out.println("Error loading data.");
- return;
- }
- DataSet dataSet = new DataSet(data);
- SumOfSquaresError errorFunction = new SumOfSquaresError();
- TestTrainSplitFilter filter = new TestTrainSplitFilter(70);
- filter.filter(dataSet);
- BackPropagationNetworkFactory nnFactory = new BackPropagationNetworkFactory();
- BackPropagationNetwork neuralNetwork = nnFactory.createClassificationNetwork(new int[] {data[0].getData().size(), 1, 1});
- NeuralNetworkOptimizationProblem weightProblem = new NeuralNetworkOptimizationProblem(filter.getTrainingSet(), neuralNetwork, errorFunction);
- //RandomizedHillClimbing rhc = new RandomizedHillClimbing(weightProblem);
- //SimulatedAnnealing sa = new SimulatedAnnealing(1.0E5, 0.25, weightProblem);
- StandardGeneticAlgorithm ga = new StandardGeneticAlgorithm(300, 72, 54, weightProblem);
- startRecording("Curve");
- run(dataSet, neuralNetwork, ga, filter);
- stopRecording();
- //gridSearch(dataSet, neuralNetwork, weightProblem, filter);
- }
- private static void gridSearch(DataSet dataSet, BackPropagationNetwork neuralNetwork, NeuralNetworkOptimizationProblem problem, TestTrainSplitFilter filter) {
- Instance[] testSet = filter.getTestingSet().getInstances();
- int[] populations = {100, 200, 300, 400, 500};
- for(int population : populations) {
- StandardGeneticAlgorithm algorithm = new StandardGeneticAlgorithm(population, (int)(0.24*population), (int)(0.18*population), problem);
- for(int i = 0; i < ITERATIONS; i++) {
- int correctTest = 0;
- algorithm.train();
- for(int j = 0; j < testSet.length; j++) {
- neuralNetwork.setInputValues(testSet[j].getData());
- neuralNetwork.run(); // feed forward
- double expected = Double.parseDouble(testSet[j].getLabel().toString());
- double out = Double.parseDouble(neuralNetwork.getOutputValues().toString());
- if(Math.abs(expected-out) < 0.5D) {
- correctTest++;
- }
- }
- if(i == ITERATIONS-1) {
- System.out.println("Pop: " + population + " = " + ((double)correctTest/testSet.length));
- }
- }
- }
- }
- private static void run(DataSet dataSet, BackPropagationNetwork neuralNetwork, OptimizationAlgorithm algorithm, TestTrainSplitFilter filter) {
- Instance[] trainSet = filter.getTrainingSet().getInstances();
- Instance[] testSet = filter.getTestingSet().getInstances();
- for(int i = 0; i < ITERATIONS; i++) {
- int correctTrain = 0, correctTest = 0;
- algorithm.train();
- for(int j = 0; j < trainSet.length; j++) {
- neuralNetwork.setInputValues(trainSet[j].getData());
- neuralNetwork.run(); // feed forward
- double expected = Double.parseDouble(trainSet[j].getLabel().toString());
- double out = Double.parseDouble(neuralNetwork.getOutputValues().toString());
- if(Math.abs(expected-out) < 0.5D) {
- correctTrain++;
- }
- }
- for(int j = 0; j < testSet.length; j++) {
- neuralNetwork.setInputValues(testSet[j].getData());
- neuralNetwork.run(); // feed forward
- double expected = Double.parseDouble(testSet[j].getLabel().toString());
- double out = Double.parseDouble(neuralNetwork.getOutputValues().toString());
- if(Math.abs(expected-out) < 0.5D) {
- correctTest++;
- }
- }
- System.out.println(i + "," + ((double)correctTrain/trainSet.length) + "," + ((double)correctTest/testSet.length));
- }
- }
- private static Instance[] getData(String s) {
- try {
- ArffDataSetReader reader = new ArffDataSetReader(getHomeDirectory() + s);
- Instance[] instances = reader.read().getInstances();
- for(int i = 0; i < instances.length; i++) {
- double[] in = new double[instances[i].getData().size()-1];
- double out = instances[i].getData().get(instances[i].getData().size()-1);
- for(int j = 0; j < instances[i].getData().size()-1; j++) {
- in[j] = instances[i].getData().get(j);
- }
- instances[i] = new Instance(in);
- instances[i].setLabel(new Instance(out == 1 ? 1 : 0));
- }
- return instances;
- } catch(Exception e) {
- e.printStackTrace();
- return null;
- }
- }
- private static String getHomeDirectory() {
- return System.getProperty("user.home");
- }
- private static long timestamp = 0;
- public static void startRecording(String s) {
- timestamp = System.currentTimeMillis();
- System.out.println("Recording time for: " + s);
- }
- public static void stopRecording() {
- long diff = System.currentTimeMillis()-timestamp;
- System.out.println("Time elapsed: " + diff);
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement