Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- /*
- * To change this license header, choose License Headers in Project Properties.
- * To change this template file, choose Tools | Templates
- * and open the template in the editor.
- */
- package nnbiocatplugin;
- import annotool.classify.SavableClassifier;
- import java.io.FileInputStream;
- import java.io.FileOutputStream;
- import java.io.IOException;
- import java.io.ObjectInputStream;
- import java.io.ObjectOutputStream;
- import java.util.HashMap;
- import java.util.TreeSet;
- import java.util.logging.Level;
- import java.util.logging.Logger;
- import neuralnet.NetworkTester;
- import neuralnet.NetworkTrainer;
- import neuralnet.NeuralNet;
- /**
- *
- * @author Sean Vogel
- */
- public class NeuralNetClassifier implements SavableClassifier {
- private int epochs = 0;
- private int hiddenNodes = 0;
- //private int inputNodes = 0;
- private int outputNodes = 0;
- private double learningRate = 0.1;
- private NeuralNet net = null;
- @Override
- public Object trainingOnly(float[][] patterns, int[] targets) throws Exception {
- //if(net == null) throw new Exception("Parameters must be set before calling trainOnly method");
- TreeSet<Integer> ts = new TreeSet<>();
- for(int t : targets) { ts.add(t); }
- outputNodes = ts.size();
- System.out.println("inputNodes = " + patterns[0].length);
- System.out.println("outputNodes = " + outputNodes);
- net = new NeuralNet(patterns[0].length, hiddenNodes, outputNodes);
- NetworkTrainer trainer = new NetworkTrainer(net);
- for(int i = 0; i < epochs; ++i) {
- trainer.train(patterns, targets);
- }
- return net;
- }
- @Override
- public Object getModel() {
- return net;
- }
- @Override
- public void setModel(Object o) throws Exception {
- if(o instanceof NeuralNet)
- net = (NeuralNet)o;
- }
- @Override
- public int classifyUsingModel(Object o, float[] testingpattern, double[] prob) throws Exception {
- int prediction = -1;
- if(o instanceof NeuralNet) {
- NetworkTester tester = new NetworkTester((NeuralNet)o);
- prediction = tester.test(testingpattern);
- }
- return prediction;
- }
- @Override
- public int[] classifyUsingModel(Object nn, float[][] testingpatterns, double[] prob) throws Exception {
- int[] predictions = null;
- if(nn instanceof NeuralNet) {
- predictions = new int[testingpatterns.length];
- NetworkTester tester = new NetworkTester((NeuralNet)nn);
- tester.test(testingpatterns, predictions);
- }
- return predictions;
- }
- @Override
- public void saveModel(Object o, String file_name) throws IOException {
- if(o instanceof NeuralNet) {
- FileOutputStream fout = new FileOutputStream(file_name);
- ObjectOutputStream oos = new ObjectOutputStream(fout);
- oos.writeObject(o);
- }
- }
- @Override
- public Object loadModel(String file_name) throws IOException {
- NeuralNet n = null;
- try {
- FileInputStream fin = new FileInputStream(file_name);
- ObjectInputStream ois = new ObjectInputStream(fin);
- n = (NeuralNet)ois.readObject();
- } catch (ClassNotFoundException ex) {
- Logger.getLogger(NeuralNetClassifier.class.getName()).log(Level.SEVERE, null, ex);
- }
- return n;
- }
- @Override
- public void setParameters(HashMap<String, String> hm) {
- //int in = Integer.parseInt(hm.get("inputNodes"));
- hiddenNodes = Integer.parseInt(hm.get("Hidden Nodes"));
- //int out = Integer.parseInt(hm.get("outputNodes"));
- epochs = Integer.parseInt(hm.get("Epochs"));
- learningRate = Double.parseDouble(hm.get("Learning Rate"));
- //System.out.println("Params: hiddenNodes="+hiddenNodes+" epochs="+epochs);
- //net = new NeuralNet(in, hid, out);
- }
- @Override
- public void classify(float[][] trainingpatterns, int[] targets, float[][] testingpatterns, int[] predictions, double[] prob) throws Exception {
- //if(net == null) throw new Exception("Parameters must be set before calling classify method");
- TreeSet<Integer> ts = new TreeSet<>();
- for(int t : targets) { ts.add(t); }
- outputNodes = ts.size();
- System.out.println("inputNodes = " + trainingpatterns[0].length);
- System.out.println("outputNodes = " + outputNodes);
- net = new NeuralNet(trainingpatterns[0].length, hiddenNodes, outputNodes);
- NetworkTrainer trainer = new NetworkTrainer(net, 0.01);
- for(int i = 0; i < epochs; ++i) {
- trainer.train(trainingpatterns, targets);
- }
- NetworkTester tester = new NetworkTester(net);
- tester.test(testingpatterns, predictions);
- }
- @Override
- public boolean doesSupportProbability() {
- return false;
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment