Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package dl;
- import java.io.File;
- import java.io.IOException;
- import org.datavec.api.records.reader.RecordReader;
- import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
- import org.datavec.api.split.FileSplit;
- import org.nd4j.linalg.activations.Activation;
- import org.nd4j.linalg.api.ndarray.INDArray;
- import org.nd4j.linalg.learning.config.Sgd;
- import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
- import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
- import org.nd4j.linalg.dataset.DataSet;
- import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
- import org.deeplearning4j.eval.Evaluation;
- import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
- import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
- import org.deeplearning4j.nn.conf.layers.Layer;
- import org.deeplearning4j.nn.conf.layers.DenseLayer;
- import org.deeplearning4j.nn.conf.layers.OutputLayer;
- import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
- import org.deeplearning4j.nn.weights.WeightInit;
- import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
- import javax.xml.crypto.Data;
- public class App1_DL4J {
- static final int input_nodes = 784;
- static final int hidden_nodes = 200;
- static final int output_nodes = 10;
- static final double learning_rate = 0.1;
- static final int random_seed = 123;
- static final int batch_size = 32;
- static final int epochs = 2;
- public static void main(String[] args)throws IOException, InterruptedException {
- MultiLayerConfiguration configuration = configureNeuralNetwork();
- MultiLayerNetwork neuralNetwork = createNeuralNetworkModel(configuration);
- DataSetIterator itfile_train;
- String training_file_name = "data/mnist_train.csv";
- itfile_train = readCSV_File(training_file_name);
- trainNeuralNetwork(neuralNetwork, itfile_train);
- String evaluete_file_name = "data/mnist_test.csv";
- DataSetIterator itfile_evaluete;
- itfile_evaluete = readCSV_File(evaluete_file_name);
- evalueteNeuralNetwork(neuralNetwork, itfile_evaluete);
- }
- private static MultiLayerConfiguration configureNeuralNetwork()
- {
- int random_seed = 123;
- Layer Wih = configureWih();
- Layer Who = configureWho();
- return (new NeuralNetConfiguration.Builder()
- .seed(random_seed)
- .updater(new Sgd(learning_rate)) // Stochastic Descending Gradient
- .list()
- .layer(0, Wih)
- .layer(1, Who)
- .pretrain(false)
- .backprop(true)
- .build());
- }
- private static MultiLayerNetwork createNeuralNetworkModel(MultiLayerConfiguration configuration)
- {
- MultiLayerNetwork model = new MultiLayerNetwork(configuration);
- model.init();
- model.setListeners(new ScoreIterationListener(1));
- return model;
- }
- private static DataSetIterator readCSV_File(String nome_arquivo) throws IOException, InterruptedException
- {
- RecordReader record = new CSVRecordReader();
- record.initialize(new FileSplit(new File(nome_arquivo)));
- return new RecordReaderDataSetIterator (record, batch_size, 0, 10);
- }
- private static void trainNeuralNetwork(MultiLayerNetwork neuralNetwork, DataSetIterator itfile)
- {
- for(int i = 0; i < epochs; i++)
- {
- neuralNetwork.fit(itfile);
- }
- }
- public static void evalueteNeuralNetwork(MultiLayerNetwork model, DataSetIterator itfile)
- {
- Evaluation score = new Evaluation(10);
- while(itfile.hasNext()){
- DataSet next_data = itfile.next();
- INDArray array = model.output(next_data.getFeatures());
- score.eval(next_data.getLabels(), array);
- }
- System.out.println(score.stats());
- }
- private static Layer configureWih()
- {
- return(new DenseLayer.Builder()
- .nIn(input_nodes)
- .nOut(hidden_nodes)
- .activation(Activation.SIGMOID)
- .weightInit(WeightInit.NORMAL)
- .build() );
- }
- private static Layer configureWho()
- {
- return (new OutputLayer.Builder(LossFunction.MSE)
- .nIn(hidden_nodes)
- .nOut(output_nodes)
- .activation(Activation.SIGMOID)
- .weightInit(WeightInit.NORMAL)
- .build() );
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement