cromat1

Untitled

Nov 29th, 2015
143
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 6.88 KB | None | 0 0
  1. /*
  2.  * Encog(tm) Java Examples v3.3
  3.  * http://www.heatonresearch.com/encog/
  4.  * https://github.com/encog/encog-java-examples
  5.  *
  6.  * Copyright 2008-2014 Heaton Research, Inc.
  7.  *
  8.  * Licensed under the Apache License, Version 2.0 (the "License");
  9.  * you may not use this file except in compliance with the License.
  10.  * You may obtain a copy of the License at
  11.  *
  12.  *     http://www.apache.org/licenses/LICENSE-2.0
  13.  *
  14.  * Unless required by applicable law or agreed to in writing, software
  15.  * distributed under the License is distributed on an "AS IS" BASIS,
  16.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17.  * See the License for the specific language governing permissions and
  18.  * limitations under the License.
  19.  *  
  20.  * For more information on Heaton Research copyrights, licenses
  21.  * and trademarks visit:
  22.  * http://www.heatonresearch.com/copyright
  23.  */
  24. package org.encog.examples.guide.classification;
  25.  
  26. import java.io.BufferedReader;
  27. import java.io.File;
  28. import java.io.FileInputStream;
  29. import java.io.InputStreamReader;
  30. import java.net.MalformedURLException;
  31. import java.net.URL;
  32. import java.util.Arrays;
  33.  
  34. import org.encog.ConsoleStatusReportable;
  35. import org.encog.Encog;
  36. import org.encog.bot.BotUtil;
  37. import org.encog.ml.MLRegression;
  38. import org.encog.ml.data.MLData;
  39. import org.encog.ml.data.versatile.NormalizationHelper;
  40. import org.encog.ml.data.versatile.VersatileMLDataSet;
  41. import org.encog.ml.data.versatile.columns.ColumnDefinition;
  42. import org.encog.ml.data.versatile.columns.ColumnType;
  43. import org.encog.ml.data.versatile.sources.CSVDataSource;
  44. import org.encog.ml.data.versatile.sources.VersatileDataSource;
  45. import org.encog.ml.factory.MLMethodFactory;
  46. import org.encog.ml.model.EncogModel;
  47. import org.encog.util.csv.CSVFormat;
  48. import org.encog.util.csv.ReadCSV;
  49. import org.encog.util.simple.EncogUtility;
  50.  
  51. public class IrisClassification {
  52.     public static String DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data";
  53.  
  54.     private String tempPath;
  55.  
  56.     public File downloadData(String[] args) throws MalformedURLException {
  57.         if (args.length != 0) {
  58.             tempPath = args[0];
  59.         } else {
  60.             tempPath = System.getProperty("java.io.tmpdir");
  61.         }
  62.  
  63.         File irisFile = new File(tempPath, "iris.csv");
  64.         BotUtil.downloadPage(new URL(IrisClassification.DATA_URL), irisFile);
  65.         System.out.println("Downloading Iris dataset to: " + irisFile);
  66.         return irisFile;
  67.     }
  68.  
  69.    
  70.    
  71.     public void run(String[] args) {
  72.         try {
  73.             // Download the data that we will attempt to model.
  74.             //File irisFile = downloadData(args);
  75.            
  76.             // Define the format of the data file.
  77.             // This area will change, depending on the columns and
  78.             // format of the file that you are trying to model.
  79.            
  80.             File pojmoviFile = new File("C:\\pojmovi.csv");
  81.            
  82.            
  83.             VersatileDataSource source = new CSVDataSource(pojmoviFile, false,
  84.                     CSVFormat.DECIMAL_POINT);
  85.             VersatileMLDataSet data = new VersatileMLDataSet(source);
  86.             data.defineSourceColumn("sepal-length", 0, ColumnType.continuous);
  87.             data.defineSourceColumn("sepal-width", 1, ColumnType.continuous);
  88.             data.defineSourceColumn("petal-length", 2, ColumnType.continuous);
  89.             data.defineSourceColumn("petal-width", 3, ColumnType.continuous);
  90.            
  91.             // Define the column that we are trying to predict.
  92.             ColumnDefinition outputColumn = data.defineSourceColumn("species", 4,
  93.                     ColumnType.nominal);
  94.            
  95.             // Analyze the data, determine the min/max/mean/sd of every column.
  96.             data.analyze();
  97.            
  98.             // Map the prediction column to the output of the model, and all
  99.             // other columns to the input.
  100.             data.defineSingleOutputOthersInput(outputColumn);
  101.            
  102.             // Create feedforward neural network as the model type. MLMethodFactory.TYPE_FEEDFORWARD.
  103.             // You could also other model types, such as:
  104.             // MLMethodFactory.SVM:  Support Vector Machine (SVM)
  105.             // MLMethodFactory.TYPE_RBFNETWORK: RBF Neural Network
  106.             // MLMethodFactor.TYPE_NEAT: NEAT Neural Network
  107.             // MLMethodFactor.TYPE_PNN: Probabilistic Neural Network
  108.             EncogModel model = new EncogModel(data);
  109.             model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD);
  110.            
  111.             // Send any output to the console.
  112.             model.setReport(new ConsoleStatusReportable());
  113.            
  114.             // Now normalize the data.  Encog will automatically determine the correct normalization
  115.             // type based on the model you chose in the last step.
  116.             data.normalize();
  117.            
  118.             // Hold back some data for a final validation.
  119.             // Shuffle the data into a random ordering.
  120.             // Use a seed of 1001 so that we always use the same holdback and will get more consistent results.
  121.             model.holdBackValidation(0.3, true, 1001);
  122.            
  123.             // Choose whatever is the default training type for this model.
  124.             model.selectTrainingType(data);
  125.            
  126.             // Use a 5-fold cross-validated train.  Return the best method found.
  127.             MLRegression bestMethod = (MLRegression)model.crossvalidate(5, true);
  128.  
  129.             // Display the training and validation errors.
  130.             System.out.println( "Training error: " + EncogUtility.calculateRegressionError(bestMethod, model.getTrainingDataset()));
  131.             System.out.println( "Validation error: " + EncogUtility.calculateRegressionError(bestMethod, model.getValidationDataset()));
  132.            
  133.             // Display our normalization parameters.
  134.             NormalizationHelper helper = data.getNormHelper();
  135.             System.out.println(helper.toString());
  136.            
  137.             // Display the final model.
  138.             System.out.println("Final model: " + bestMethod);
  139.            
  140.             // Loop over the entire, original, dataset and feed it through the model.
  141.             // This also shows how you would process new data, that was not part of your
  142.             // training set.  You do not need to retrain, simply use the NormalizationHelper
  143.             // class.  After you train, you can save the NormalizationHelper to later
  144.             // normalize and denormalize your data.
  145.             ReadCSV csv = new ReadCSV(pojmoviFile, false, CSVFormat.DECIMAL_POINT);
  146.             String[] line = new String[4];
  147.             MLData input = helper.allocateInputVector();
  148.            
  149.             while(csv.next()) {
  150.                 StringBuilder result = new StringBuilder();
  151.                 line[0] = csv.get(0);
  152.                 line[1] = csv.get(1);
  153.                 line[2] = csv.get(2);
  154.                 line[3] = csv.get(3);
  155.                 String correct = csv.get(4);
  156.                 helper.normalizeInputVector(line,input.getData(),false);
  157.                 MLData output = bestMethod.compute(input);
  158.                 String irisChosen = helper.denormalizeOutputVectorToString(output)[0];
  159.                
  160.                 result.append(Arrays.toString(line));
  161.                 result.append(" -> predicted: ");
  162.                 result.append(irisChosen);
  163.                 result.append("(correct: ");
  164.                 result.append(correct);
  165.                 result.append(")");
  166.                
  167.                 System.out.println(result.toString());
  168.             }
  169.            
  170.             // Delete data file ande shut down.
  171.             //irisFile.delete();
  172.             Encog.getInstance().shutdown();
  173.  
  174.         } catch (Exception ex) {
  175.             ex.printStackTrace();
  176.         }
  177.     }
  178.  
  179.     public static void main(String[] args) {
  180.         IrisClassification prg = new IrisClassification();
  181.         prg.run(args);
  182.     }
  183. }
Advertisement
Add Comment
Please, Sign In to add comment