naraku9333

NeuralNetClassifier

Dec 4th, 2015
133
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 4.98 KB | None | 0 0
  1. /*
  2.  * To change this license header, choose License Headers in Project Properties.
  3.  * To change this template file, choose Tools | Templates
  4.  * and open the template in the editor.
  5.  */
  6. package nnbiocatplugin;
  7.  
  8. import annotool.classify.SavableClassifier;
  9. import java.io.FileInputStream;
  10. import java.io.FileOutputStream;
  11. import java.io.IOException;
  12. import java.io.ObjectInputStream;
  13. import java.io.ObjectOutputStream;
  14. import java.util.HashMap;
  15. import java.util.TreeSet;
  16. import java.util.logging.Level;
  17. import java.util.logging.Logger;
  18. import neuralnet.NetworkTester;
  19. import neuralnet.NetworkTrainer;
  20. import neuralnet.NeuralNet;
  21.  
  22. /**
  23.  *
  24.  * @author Sean Vogel
  25.  */
  26. public class NeuralNetClassifier implements SavableClassifier {
  27.  
  28.     private int epochs = 0;
  29.     private int hiddenNodes = 0;
  30.     //private int inputNodes = 0;
  31.     private int outputNodes = 0;
  32.     private double learningRate = 0.1;
  33.     private NeuralNet net = null;
  34.    
  35.     @Override
  36.     public Object trainingOnly(float[][] patterns, int[] targets) throws Exception {
  37.         //if(net == null) throw new Exception("Parameters must be set before calling trainOnly method");
  38.         TreeSet<Integer> ts = new TreeSet<>();
  39.         for(int t : targets) { ts.add(t); }
  40.         outputNodes = ts.size();
  41.         System.out.println("inputNodes = " + patterns[0].length);
  42.         System.out.println("outputNodes = " + outputNodes);
  43.        
  44.         net = new NeuralNet(patterns[0].length, hiddenNodes, outputNodes);
  45.         NetworkTrainer trainer = new NetworkTrainer(net);
  46.         for(int i = 0; i < epochs; ++i) {
  47.             trainer.train(patterns, targets);
  48.         }
  49.        
  50.         return net;
  51.     }
  52.  
  53.     @Override
  54.     public Object getModel() {
  55.         return net;
  56.     }
  57.  
  58.     @Override
  59.     public void setModel(Object o) throws Exception {
  60.         if(o instanceof NeuralNet)
  61.             net = (NeuralNet)o;
  62.     }
  63.  
  64.     @Override
  65.     public int classifyUsingModel(Object o, float[] testingpattern, double[] prob) throws Exception {
  66.         int prediction = -1;
  67.         if(o instanceof NeuralNet) {
  68.             NetworkTester tester = new NetworkTester((NeuralNet)o);
  69.             prediction = tester.test(testingpattern);
  70.         }
  71.         return prediction;
  72.     }
  73.  
  74.     @Override
  75.     public int[] classifyUsingModel(Object nn, float[][] testingpatterns, double[] prob) throws Exception {
  76.         int[] predictions = null;
  77.         if(nn instanceof NeuralNet) {
  78.             predictions = new int[testingpatterns.length];
  79.             NetworkTester tester = new NetworkTester((NeuralNet)nn);
  80.             tester.test(testingpatterns, predictions);
  81.         }
  82.         return predictions;
  83.     }
  84.  
  85.     @Override
  86.     public void saveModel(Object o, String file_name) throws IOException {
  87.         if(o instanceof NeuralNet) {
  88.             FileOutputStream fout = new FileOutputStream(file_name);
  89.             ObjectOutputStream oos = new ObjectOutputStream(fout);  
  90.             oos.writeObject(o);
  91.         }
  92.     }
  93.  
  94.     @Override
  95.     public Object loadModel(String file_name) throws IOException {
  96.         NeuralNet n = null;
  97.         try {
  98.             FileInputStream fin = new FileInputStream(file_name);
  99.             ObjectInputStream ois = new ObjectInputStream(fin);
  100.             n = (NeuralNet)ois.readObject();
  101.         } catch (ClassNotFoundException ex) {
  102.             Logger.getLogger(NeuralNetClassifier.class.getName()).log(Level.SEVERE, null, ex);
  103.         }
  104.         return n;
  105.     }
  106.  
  107.     @Override
  108.     public void setParameters(HashMap<String, String> hm) {
  109.         //int in = Integer.parseInt(hm.get("inputNodes"));
  110.         hiddenNodes = Integer.parseInt(hm.get("Hidden Nodes"));
  111.         //int out = Integer.parseInt(hm.get("outputNodes"));
  112.         epochs = Integer.parseInt(hm.get("Epochs"));
  113.         learningRate = Double.parseDouble(hm.get("Learning Rate"));
  114.         //System.out.println("Params: hiddenNodes="+hiddenNodes+" epochs="+epochs);
  115.         //net = new NeuralNet(in, hid, out);
  116.     }
  117.  
  118.     @Override
  119.     public void classify(float[][] trainingpatterns, int[] targets, float[][] testingpatterns, int[] predictions, double[] prob) throws Exception {
  120.         //if(net == null) throw new Exception("Parameters must be set before calling classify method");
  121.         TreeSet<Integer> ts = new TreeSet<>();
  122.         for(int t : targets) { ts.add(t); }
  123.         outputNodes = ts.size();
  124.         System.out.println("inputNodes = " + trainingpatterns[0].length);
  125.         System.out.println("outputNodes = " + outputNodes);
  126.        
  127.         net = new NeuralNet(trainingpatterns[0].length, hiddenNodes, outputNodes);
  128.         NetworkTrainer trainer = new NetworkTrainer(net, 0.01);
  129.         for(int i = 0; i < epochs; ++i) {
  130.             trainer.train(trainingpatterns, targets);
  131.         }
  132.        
  133.         NetworkTester tester = new NetworkTester(net);
  134.         tester.test(testingpatterns, predictions);
  135.     }
  136.  
  137.     @Override
  138.     public boolean doesSupportProbability() {
  139.         return false;
  140.     }    
  141. }
Advertisement
Add Comment
Please, Sign In to add comment