Advertisement
vonnik

Neural Network Regression for the Sinc Function, sin(x)/x

Oct 19th, 2015
3,796
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 2.73 KB | None | 0 0
  1. package org.deeplearning4j.examples.regression
  2.  
  3. import org.canova.api.records.reader.RecordReader
  4. import org.canova.api.records.reader.impl.CSVRecordReader
  5. import org.canova.api.split.FileSplit
  6. import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator
  7. import org.deeplearning4j.datasets.iterator.DataSetIterator
  8. import org.deeplearning4j.nn.api.OptimizationAlgorithm
  9. import org.deeplearning4j.nn.conf.{ MultiLayerConfiguration, NeuralNetConfiguration }
  10. import org.deeplearning4j.nn.conf.Updater
  11. import org.deeplearning4j.nn.conf.layers.{DenseLayer, OutputLayer}
  12. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
  13. import org.deeplearning4j.nn.weights.WeightInit
  14. import org.deeplearning4j.optimize.listeners.ScoreIterationListener
  15. import org.nd4j.linalg.api.ndarray.INDArray
  16. import org.nd4j.linalg.api.ops.impl.transforms.Sin
  17. import org.nd4j.linalg.dataset.DataSet
  18. import org.nd4j.linalg.factory.Nd4j
  19. import org.nd4j.linalg.lossfunctions.LossFunctions
  20. import scalax.chart.api._
  21. import org.nd4s.Implicits._
  22.  
  23. object RegressionExample {
  24.   def main(args: Array[String]) = {
  25.  
  26.     def plotXY(x:INDArray, y:INDArray):Unit = {
  27.  
  28.       val dataPlot = for(i <- 0 to y.length()-1) yield (x.getFloat(i), y.getFloat(i))
  29.       val chart = XYLineChart(dataPlot)
  30.       chart.show()
  31.     }
  32.  
  33.  
  34.     Nd4j.ENFORCE_NUMERICAL_STABILITY = true
  35.     val numInputs = 1
  36.     val numOutputs = 1
  37.     val numHiddenNodes = 20
  38.     val nSamples = 1000
  39.     val x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1)
  40.     val y = Nd4j.getExecutioner().execAndReturn(new Sin(x, x.dup())).div(x)
  41.     //val y = x.mul(2)
  42.     val dataSet = new DataSet(x, y)
  43.  
  44.     val seed = 0
  45.     val iterations = 5000
  46.     val numEpochs = 1
  47.     val learningRate = 100
  48.  
  49.     //plotXY(x, y)
  50.  
  51.     val conf = new NeuralNetConfiguration.Builder()
  52.       .seed(seed)
  53.       .constrainGradientToUnitNorm(false)
  54.       .iterations(iterations)
  55.       .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
  56.       .learningRate(learningRate)
  57.       .list(2)
  58.       .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
  59.         .weightInit(WeightInit.XAVIER)
  60.         .activation("tanh")
  61.         .build())
  62.       .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
  63.         .weightInit(WeightInit.XAVIER)
  64.         .activation("identity").weightInit(WeightInit.XAVIER)
  65.         .nIn(numHiddenNodes).nOut(numOutputs).build()).backprop(true)
  66.       .build()
  67.  
  68.     val network = new MultiLayerNetwork(conf)
  69.     network.init()
  70.     network.setListeners(new ScoreIterationListener(1))
  71.     network.fit(dataSet)
  72.     val yEst = network.output(x)
  73.     val error = Nd4j.std(y.sub(yEst))
  74.  
  75.     println("Z" + yEst)
  76.     println("Y" + y)
  77.  
  78.     plotXY(x, yEst)
  79.   }
  80. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement