lemueltra

treina_10_05

May 11th, 2016
270
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 3.29 KB | None | 0 0
  1. package rna_seguidor;
  2.  
  3. import java.io.File;
  4.  
  5. import org.encog.engine.network.activation.ActivationSigmoid;
  6. import org.encog.mathutil.randomize.ConsistentRandomizer;
  7. import org.encog.ml.data.MLData;
  8. import org.encog.ml.data.MLDataPair;
  9. import org.encog.ml.data.MLDataSet;
  10. import org.encog.ml.data.basic.BasicMLDataSet;
  11. import org.encog.neural.networks.BasicNetwork;
  12. import org.encog.neural.networks.layers.BasicLayer;
  13. import org.encog.neural.networks.training.Train;
  14. import org.encog.neural.networks.training.propagation.back.Backpropagation;
  15. import org.encog.persist.EncogDirectoryPersistence;
  16.  
  17. public class treina {
  18.     public static double X[][] = { { 0.0, 0.0, 0.0 }, //Dados de entrada para treinamento
  19.                                    { 0.0, 0.0, 0.0 },
  20.                                    { 0.0, 0.0, 0.0 },
  21.                                    { 0.0, 0.0, 0.0 },
  22.                                    { 0.0, 0.0, 0.0 },
  23.                                    { 0.0, 0.0, 0.0 },
  24.                                    { 0.0, 0.0, 0.0 },
  25.                                    { 0.0, 0.0, 0.0 }
  26.                                  };
  27.     public static double Y[][] = { { 0.0, 0.0 }, //Dados de saída para treinamento
  28.                                    { 0.0, 0.0 },
  29.                                    { 0.0, 0.0 },
  30.                                    { 0.0, 0.0 },
  31.                                    { 0.0, 0.0 },
  32.                                    { 0.0, 0.0 },
  33.                                    { 0.0, 0.0 },
  34.                                    { 0.0, 0.0 }
  35.                                  };
  36.    
  37.     public static void main(final String args[]) {
  38.         final BasicNetwork rna = new BasicNetwork(); //Cria rede neural feedforward
  39.         rna.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3)); //Camada de entrada com 3 neurônios (função de ativação tangente sigmoidal e bias ativado)
  40.         rna.addLayer(new BasicLayer(new ActivationSigmoid(), true, 5)); //Camada de neurônios ocultos com 5 neurônios
  41.         rna.addLayer(new BasicLayer(new ActivationSigmoid(), false, 2)); //Camada de saída com 2 neurônios
  42.         rna.getStructure().finalizeStructure(); //Finaliza a configuração da rede
  43.         rna.reset();
  44.        
  45.         new ConsistentRandomizer(-1,1,500).randomize(rna); //Valores aleatórios nos pesos (valor mínimo, valor máximo, seed)
  46.         System.out.println("Valores iniciais dos pesos:\n" + rna.dumpWeights() + "\n");
  47.  
  48.         MLDataSet data_treina = new BasicMLDataSet(X, Y); //Setar arquivos de treinamento
  49.        
  50.         //Treinamento em backpropagation
  51.         final Train treina = new Backpropagation(rna,data_treina,0.7, 0.9); //(rede a ser treinada, arquivos de treinamento, taxa de aprendizagem, taxa de momentum)
  52.         int epoca = 1;
  53.        
  54.         System.out.println("Treinamento:");
  55.         do {  
  56.             treina.iteration(); //Backpropagation expõe os dados à rede
  57.             System.out.println("Época: " + epoca + " | Erro:" + treina.getError());
  58.             epoca++;
  59.         } while((epoca < 5000) && (treina.getError() > 0.001)); //A rede irá treinar até quando a taxa de erro for menor que 0,1% ou atingir 5000 épocas
  60.        
  61.         System.out.println("\nResultados do Treinamento da RNA:\n");
  62.         for (MLDataPair data_pair: data_treina) { //Data_pair recebe os dados de treinamento
  63.             final MLData saida = rna.compute(data_pair.getInput()); //Saida recebe os dados de saída da RNA
  64.             System.out.println("Saída 1 da RNA= " + saida.getData(0) + " | Valor ideal= " + data_pair.getIdeal().getData(0));
  65.             System.out.println("Saída 2 da RNA= " + saida.getData(1) + " | Valor ideal= " + data_pair.getIdeal().getData(1) + "\n");
  66.         }
  67.         EncogDirectoryPersistence.saveObject(new File("rna_seguidor.eg"), rna); //Salva o treinamento em um arquivo .eg
  68.     }
  69. }
Add Comment
Please, Sign In to add comment