Advertisement
cromat1

Encog XOR string input

Dec 1st, 2015
658
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 5.88 KB | None | 0 0
  1. package nrp.nrp;
  2.  
  3. import java.io.BufferedReader;
  4. import java.io.BufferedWriter;
  5. import java.io.File;
  6. import java.io.FileNotFoundException;
  7. import java.io.FileReader;
  8. import java.io.FileWriter;
  9. import java.io.IOException;
  10. import java.io.UnsupportedEncodingException;
  11. import java.util.Arrays;
  12. import java.util.Scanner;
  13.  
  14. import org.encog.ConsoleStatusReportable;
  15. import org.encog.Encog;
  16. import org.encog.engine.network.activation.ActivationBipolarSteepenedSigmoid;
  17. import org.encog.engine.network.activation.ActivationClippedLinear;
  18. import org.encog.engine.network.activation.ActivationCompetitive;
  19. import org.encog.engine.network.activation.ActivationElliott;
  20. import org.encog.engine.network.activation.ActivationGaussian;
  21. import org.encog.engine.network.activation.ActivationSigmoid;
  22. import org.encog.engine.network.activation.ActivationSoftMax;
  23. import org.encog.engine.network.activation.ActivationSteepenedSigmoid;
  24. import org.encog.engine.network.activation.ActivationStep;
  25. import org.encog.ml.MLRegression;
  26. import org.encog.ml.data.MLData;
  27. import org.encog.ml.data.MLDataPair;
  28. import org.encog.ml.data.MLDataSet;
  29. import org.encog.ml.data.basic.BasicMLDataSet;
  30. import org.encog.ml.data.versatile.NormalizationHelper;
  31. import org.encog.ml.data.versatile.VersatileMLDataSet;
  32. import org.encog.ml.data.versatile.columns.ColumnDefinition;
  33. import org.encog.ml.data.versatile.columns.ColumnType;
  34. import org.encog.ml.data.versatile.sources.CSVDataSource;
  35. import org.encog.ml.data.versatile.sources.VersatileDataSource;
  36. import org.encog.ml.factory.MLMethodFactory;
  37. import org.encog.ml.model.EncogModel;
  38. import org.encog.neural.networks.BasicNetwork;
  39. import org.encog.neural.networks.layers.BasicLayer;
  40. import org.encog.neural.networks.training.propagation.back.Backpropagation;
  41. import org.encog.neural.networks.training.propagation.quick.QuickPropagation;
  42. import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
  43. import org.encog.util.csv.CSVFormat;
  44. import org.encog.util.csv.ReadCSV;
  45. import org.encog.util.simple.EncogUtility;
  46.  
  47. public class App {
  48.  
  49.     public static String VRIJEDNOSTI[][] = { { "pas", "pas" },
  50.             { "cucak", "pas" }, { "pesek", "pas" }, { "pes", "pas" },
  51.             { "peso", "pas" }, { "macka", "macka" }, { "maca", "macka" },
  52.             { "mica", "macka" }, { "traktor", "traktor" },
  53.             { "trakac", "traktor" }, { "deutz", "traktor" },
  54.             { "john deere", "traktor" } };
  55.  
  56.     /**
  57.      * The input necessary for XOR.
  58.      */
  59.     public static double ULAZNI_POJMOVI[][];
  60.  
  61.     /**
  62.      * The ideal data necessary for XOR.
  63.      */
  64.     public static double IDEALNI_IZLAZI[][];
  65.  
  66.     public static int NAJVECI;
  67.  
  68.     public static void main(final String args[]) {
  69.  
  70.         NAJVECI = 0;
  71.         for (int i = 0; i < VRIJEDNOSTI.length; i++) {
  72.             if (VRIJEDNOSTI[i][0].length() > NAJVECI)
  73.                 NAJVECI = VRIJEDNOSTI[i][0].length();
  74.         }
  75.  
  76.         // System.out.println(najveci);
  77.  
  78.         ULAZNI_POJMOVI = new double[VRIJEDNOSTI.length][NAJVECI];
  79.         IDEALNI_IZLAZI = new double[VRIJEDNOSTI.length][1];
  80.  
  81.         // dodavanje ulaza (slova) i normalizacija
  82.         for (int i = 0; i < VRIJEDNOSTI.length; i++) {
  83.             Arrays.fill(ULAZNI_POJMOVI[i], 0);
  84.             ULAZNI_POJMOVI[i] = toAscii(VRIJEDNOSTI[i][0], NAJVECI);
  85.             for (int j = 0; j < ULAZNI_POJMOVI[i].length; j++) {
  86.                 // System.out.print(ULAZNI_POJMOVI[i][j] + ", ");
  87.             }
  88.             System.out.println("Prosjeci ideal:");
  89.  
  90.             IDEALNI_IZLAZI[i][0] = average(toAscii(VRIJEDNOSTI[i][1], NAJVECI));
  91.             System.out.println(IDEALNI_IZLAZI[i][0]);
  92.         }
  93.  
  94.         // kreiranje neuronske mreze
  95.         BasicNetwork network = new BasicNetwork();
  96.         network.addLayer(new BasicLayer(null, true, NAJVECI));
  97.         network.addLayer(new BasicLayer(new ActivationSoftMax(), true, 3));
  98.         //network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 3));
  99.         network.addLayer(new BasicLayer(new ActivationSoftMax(), false, 1));
  100.         network.getStructure().finalizeStructure();
  101.         network.reset();
  102.  
  103.         // set za treniranje
  104.         MLDataSet trainingSet = new BasicMLDataSet(ULAZNI_POJMOVI,
  105.                 IDEALNI_IZLAZI);
  106.  
  107.         // treniranje
  108.         final QuickPropagation train = new QuickPropagation(network, trainingSet);
  109.  
  110.         int epoch = 1;
  111.  
  112.         do {
  113.             train.iteration();
  114.             System.out
  115.                     .println("Epoch #" + epoch + " Error:" + train.getError());
  116.             epoch++;
  117.         } while (train.getError() > 0.00001);
  118.         train.finishTraining();
  119.  
  120.         double tocni = 0;
  121.        
  122.         // testiranje neuronske mreze
  123.         System.out.println("Neural Network Results:");
  124.         for (MLDataPair pair : trainingSet) {
  125.             final MLData output = network.compute(pair.getInput());
  126.  
  127.             System.out.println(pair.getInput().getData(0) + ", actual="
  128.                     + output.getData(0) + " , "
  129.                     + denormaliziraj(output.getData(0)) + " ,ideal="
  130.                     + pair.getIdeal().getData(0) + " , "
  131.                     + denormaliziraj(pair.getIdeal().getData(0)));
  132.            
  133.             if(denormaliziraj(output.getData(0)) == denormaliziraj(pair.getIdeal().getData(0)))
  134.             tocni++;
  135.         }
  136.        
  137.         System.out.println("Uspješnost: " + (tocni/IDEALNI_IZLAZI.length)*100 + "%");
  138.        
  139.         Encog.getInstance().shutdown();
  140.        
  141.     }
  142.  
  143.     public static double average(double[] ascii) {
  144.         double sum = 0;
  145.         for (double val : ascii)
  146.             sum += val;
  147.         return sum / ascii.length;
  148.     }
  149.  
  150.     // zbraja ascii vrijednosti svakog slova u stringu
  151.     public static double[] toAscii(String s, int najveci) {
  152.         double[] ascii = new double[najveci];
  153.         try {
  154.             byte[] bytes = s.getBytes("US-ASCII");
  155.             for (int i = 0; i < bytes.length; i++) {
  156.                 ascii[i] = 90.0 / bytes[i];
  157.             }
  158.  
  159.         } catch (UnsupportedEncodingException e) {
  160.             e.printStackTrace();
  161.         }
  162.         return ascii;
  163.     }
  164.  
  165.     // denormalizira double vrijednosti u string
  166.     public static String denormaliziraj(double trenutna) {
  167.         double najmanjaRazlika = 1;
  168.         int indeks = 0;
  169.         for (int i = 0; i < IDEALNI_IZLAZI.length; i++) {
  170.             if (Math.abs(IDEALNI_IZLAZI[i][0] - trenutna) < najmanjaRazlika) {
  171.                 najmanjaRazlika = Math.abs(IDEALNI_IZLAZI[i][0] - trenutna);
  172.                 indeks = i;
  173.             }
  174.         }
  175.         return VRIJEDNOSTI[indeks][1];
  176.     }
  177. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement