Not a member of Pastebin yet?
Sign Up — it unlocks many cool features!
- import org.apache.commons.io.FilenameUtils
- import org.apache.spark.{SparkConf, SparkContext}
- import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
- import org.deeplearning4j.nn.conf.MultiLayerConfiguration
- import org.deeplearning4j.nn.conf.NeuralNetConfiguration
- import org.deeplearning4j.nn.conf.inputs.InputType
- import org.deeplearning4j.nn.conf.layers._
- import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
- import org.deeplearning4j.nn.weights.WeightInit
- import org.deeplearning4j.optimize.api.InvocationType
- import org.deeplearning4j.optimize.listeners.EvaluativeListener
- import org.deeplearning4j.optimize.listeners.ScoreIterationListener
- import org.nd4j.linalg.activations.Activation
- import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
- import org.nd4j.linalg.learning.config.Adam
- import org.nd4j.linalg.lossfunctions.LossFunctions
- import org.nd4j.linalg.learning.config
- import org.nd4j.linalg.learning.config.IUpdater
- import org.slf4j.Logger
- import org.slf4j.LoggerFactory
- import java.io.File
- import java.io.File
object SimpleApp {

  /**
   * Trains a LeNet-style convolutional network on the MNIST dataset with
   * Deeplearning4j: two convolution + max-pooling stages, a dense layer,
   * and a softmax output, evaluated on the test set at the end of each epoch.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val nChannels = 1   // Number of input channels (MNIST images are grayscale)
    val outputNum = 10  // The number of possible outcomes (digits 0-9)
    val batchSize = 64  // Test batch size
    val nEpochs   = 1   // Number of training epochs
    val seed      = 123 // RNG seed for reproducible weight initialization

    // Create iterators using the batch size for one iteration.
    println("Load data....")
    val mnistTrain = new MnistDataSetIterator(batchSize, true, 12345)
    val mnistTest  = new MnistDataSetIterator(batchSize, false, 12345)

    // Construct the neural network.
    println("Build model....")
    val conf = new NeuralNetConfiguration.Builder()
      .seed(seed)
      .l2(0.0005)
      .weightInit(WeightInit.XAVIER)
      .updater(new Adam(1e-3))
      .list
      // nIn and nOut specify depth: nIn here is nChannels and nOut is the
      // number of filters to be applied.
      .layer(new ConvolutionLayer.Builder(5, 5)
        .nIn(nChannels)
        .stride(1, 1)
        .nOut(20)
        .activation(Activation.IDENTITY)
        .build)
      .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
        .kernelSize(2, 2)
        .stride(2, 2)
        .build)
      // Note that nIn need not be specified in later layers; it is inferred
      // via the input type set below.
      .layer(new ConvolutionLayer.Builder(5, 5)
        .stride(1, 1)
        .nOut(50)
        .activation(Activation.IDENTITY)
        .build)
      .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
        .kernelSize(2, 2)
        .stride(2, 2)
        .build)
      .layer(new DenseLayer.Builder()
        .activation(Activation.RELU)
        .nOut(500)
        .build)
      .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .nOut(outputNum)
        .activation(Activation.SOFTMAX)
        .build)
      .setInputType(InputType.convolutionalFlat(28, 28, 1)) // See note below
      .build
    /*
    Regarding the .setInputType(InputType.convolutionalFlat(28,28,1)) line: This does a few things.
    (a) It adds preprocessors, which handle things like the transition between the convolutional/subsampling layers
        and the dense layer
    (b) Does some additional configuration validation
    (c) Where necessary, sets the nIn (number of input neurons, or input depth in the case of CNNs) values for each
        layer based on the size of the previous layer (but it won't override values manually set by the user)

    InputTypes can be used with other layer types too (RNNs, MLPs etc) not just CNNs.
    For normal images (when using ImageRecordReader) use InputType.convolutional(height,width,depth).
    MNIST record reader is a special case, that outputs 28x28 pixel grayscale (nChannels=1) images, in a "flattened"
    row vector format (i.e., 1x784 vectors), hence the "convolutionalFlat" input type used here.
    */
    val model = new MultiLayerNetwork(conf)
    model.init()

    println("Train model...")
    // Print score every 10 iterations and evaluate on the test set every epoch.
    model.setListeners(
      new ScoreIterationListener(10),
      new EvaluativeListener(mnistTest, 1, InvocationType.EPOCH_END))
    model.fit(mnistTrain, nEpochs)

    val path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "lenetmnist.zip")
    println(s"Saving model to tmp folder: $path")
    // NOTE(review): the save call is disabled, so the message above is
    // misleading — nothing is actually written to disk. Re-enable or adjust
    // the message once the intended behavior is confirmed.
    //model.save(new File(path), true)
    println("****************Example finished********************")
  }
}
Advertisement
Add Comment
Please sign in to add a comment.