Untitled

a guest
Apr 10th, 2021

import org.apache.commons.io.FilenameUtils
import org.apache.spark.{SparkConf, SparkContext}
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.layers._
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4j.optimize.api.InvocationType
import org.deeplearning4j.optimize.listeners.EvaluativeListener
import org.deeplearning4j.optimize.listeners.ScoreIterationListener
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
import org.nd4j.linalg.learning.config.Adam
import org.nd4j.linalg.lossfunctions.LossFunctions
import org.slf4j.Logger
import org.slf4j.LoggerFactory

import java.io.File

object SimpleApp {
  def main(args: Array[String]): Unit = {
    // Spark Pi example, disabled (it would also need `import scala.math.random`):
    // val conf = new SparkConf().setAppName("Spark Pi")
    // val spark = new SparkContext(conf)
    // val slices = if (args.length > 0) args(0).toInt else 2
    // val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow
    // val count = spark.parallelize(1 until n, slices).map { i =>
    //   val x = random * 2 - 1
    //   val y = random * 2 - 1
    //   if (x*x + y*y < 1) 1 else 0
    // }.reduce(_ + _)
    // println("Pi is roughly " + 4.0 * count / n)
    // spark.stop()
    val nChannels = 1  // Number of input channels (MNIST images are grayscale)
    val outputNum = 10 // Number of possible outcomes (digits 0-9)
    val batchSize = 64 // Minibatch size, used for both training and testing
    val nEpochs = 1    // Number of training epochs
    val seed = 123     // Random seed for reproducible results

    /*
    Create iterators that return one minibatch of batchSize examples per iteration
    */
    println("Load data....")
    val mnistTrain = new MnistDataSetIterator(batchSize, true, 12345)
    val mnistTest = new MnistDataSetIterator(batchSize, false, 12345)

    /*
    Construct the neural network
    */
    println("Build model....")
    val conf: MultiLayerConfiguration = new NeuralNetConfiguration.Builder()
      .seed(seed)
      .l2(0.0005)
      .weightInit(WeightInit.XAVIER)
      .updater(new Adam(1e-3))
      .list
      // nIn and nOut specify depth: nIn here is nChannels and nOut is the number of filters to be applied
      .layer(new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build)
      .layer(new SubsamplingLayer.Builder(PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build)
      // Note that nIn need not be specified in later layers
      .layer(new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build)
      .layer(new SubsamplingLayer.Builder(PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build)
      .layer(new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build)
      .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).activation(Activation.SOFTMAX).build)
      .setInputType(InputType.convolutionalFlat(28, 28, 1)).build // See note below

    /*
    Regarding the .setInputType(InputType.convolutionalFlat(28, 28, 1)) line: it does a few things.
    (a) It adds preprocessors, which handle things like the transition between the convolutional/subsampling
        layers and the dense layer.
    (b) It performs some additional configuration validation.
    (c) Where necessary, it sets the nIn (number of input neurons, or input depth in the case of CNNs) value
        of each layer based on the size of the previous layer (but it won't override values set manually by the user).
    InputTypes can be used with other layer types too (RNNs, MLPs, etc.), not just CNNs.
    For normal images (when using ImageRecordReader), use InputType.convolutional(height, width, depth).
    MnistDataSetIterator is a special case: it outputs 28x28 pixel grayscale (nChannels = 1) images in a
    "flattened" row vector format (i.e., 1x784 vectors), hence the "convolutionalFlat" input type used here.
    */
    val model = new MultiLayerNetwork(conf)
    model.init()
    println("Train model...")
    // Print the score every 10 iterations and evaluate on the test set at the end of every epoch
    model.setListeners(new ScoreIterationListener(10), new EvaluativeListener(mnistTest, 1, InvocationType.EPOCH_END))

    model.fit(mnistTrain, nEpochs)

    val path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "lenetmnist.zip")
    println("Saving model to tmp folder: " + path)
    model.save(new File(path), true) // true = also save the updater (Adam) state
    println("****************Example finished********************")
  }
}
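The note on `setInputType` says DL4J fills in each layer's nIn from the size of the previous layer. That arithmetic can be checked by hand. The sketch below assumes no padding and DL4J's default `ConvolutionMode.Truncate`, where each window gives an output size of (in - kernel) / stride + 1; the object `LeNetShapes` and its members are hypothetical names used only for illustration.

```scala
// Hypothetical helper (not part of DL4J): recomputes the layer sizes that
// setInputType(InputType.convolutionalFlat(28, 28, 1)) infers for the network above.
object LeNetShapes {
  // Spatial output size of a conv/pooling window with no padding
  // (assumes DL4J's default ConvolutionMode.Truncate).
  def outSize(in: Int, kernel: Int, stride: Int): Int = (in - kernel) / stride + 1

  val conv1: Int = outSize(28, 5, 1)    // 5x5 conv, 20 filters  -> 24x24x20
  val pool1: Int = outSize(conv1, 2, 2) // 2x2 max pool, stride 2 -> 12x12x20
  val conv2: Int = outSize(pool1, 5, 1) // 5x5 conv, 50 filters  -> 8x8x50 (nIn = 20)
  val pool2: Int = outSize(conv2, 2, 2) // 2x2 max pool, stride 2 -> 4x4x50

  // Flattened activations fed into the DenseLayer: the nIn that setInputType supplies.
  val denseNIn: Int = pool2 * pool2 * 50

  def main(args: Array[String]): Unit =
    println(s"conv1=$conv1 pool1=$pool1 conv2=$conv2 pool2=$pool2 denseNIn=$denseNIn")
}
```

Running it prints `conv1=24 pool1=12 conv2=8 pool2=4 denseNIn=800`, so the DenseLayer receives 800 inputs and the OutputLayer receives the DenseLayer's 500 outputs, none of which had to be written in the configuration.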
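The listener comment says the score is printed every 10 iterations, where one iteration is one minibatch. The sketch below estimates how many iterations one epoch takes, assuming the standard 60,000-image MNIST training set and that the iterator returns the final partial minibatch rather than dropping it; `MnistIterations` and its fields are hypothetical names for illustration only.

```scala
// Hypothetical helper: iterations per epoch for the training setup in SimpleApp.
object MnistIterations {
  val trainExamples: Int = 60000 // assumed: standard MNIST training-set size
  val batchSize: Int = 64        // same value as batchSize in SimpleApp
  val scoreEvery: Int = 10       // ScoreIterationListener(10)

  // Ceiling division: the last, partial minibatch (60000 - 937 * 64 = 32 examples)
  // still counts as one iteration, assuming the iterator does not drop it.
  val iterationsPerEpoch: Int = (trainExamples + batchSize - 1) / batchSize

  // Approximate number of score lines per epoch; the exact count depends on
  // whether the listener also fires on the very first iteration.
  val approxScoreLines: Int = iterationsPerEpoch / scoreEvery

  def main(args: Array[String]): Unit =
    println(s"iterations per epoch = $iterationsPerEpoch, about $approxScoreLines score lines")
}
```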