Advertisement
Not a member of Pastebin yet?
Sign up; it unlocks many cool features!
- package org.deeplearning4j.examples.regression
- import org.canova.api.records.reader.RecordReader
- import org.canova.api.records.reader.impl.CSVRecordReader
- import org.canova.api.split.FileSplit
- import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator
- import org.deeplearning4j.datasets.iterator.DataSetIterator
- import org.deeplearning4j.nn.api.OptimizationAlgorithm
- import org.deeplearning4j.nn.conf.{ MultiLayerConfiguration, NeuralNetConfiguration }
- import org.deeplearning4j.nn.conf.Updater
- import org.deeplearning4j.nn.conf.layers.{DenseLayer, OutputLayer}
- import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
- import org.deeplearning4j.nn.weights.WeightInit
- import org.deeplearning4j.optimize.listeners.ScoreIterationListener
- import org.nd4j.linalg.api.ndarray.INDArray
- import org.nd4j.linalg.api.ops.impl.transforms.Sin
- import org.nd4j.linalg.dataset.DataSet
- import org.nd4j.linalg.factory.Nd4j
- import org.nd4j.linalg.lossfunctions.LossFunctions
- import scalax.chart.api._
- import org.nd4s.Implicits._
/**
 * Fits a small feed-forward network (1 input, one tanh hidden layer, 1 linear
 * output) to the function y = sin(x) / x sampled on [-10, 10], then prints the
 * residual standard deviation and the predictions, and plots the predictions.
 */
object RegressionExample {

  /**
   * Opens an XY line chart of the pairs (x(i), y(i)) for i in [0, y.length()).
   *
   * @param x abscissa values; must have at least `y.length()` elements
   * @param y ordinate values
   */
  private def plotXY(x: INDArray, y: INDArray): Unit = {
    val dataPlot = for (i <- 0 until y.length()) yield (x.getFloat(i), y.getFloat(i))
    XYLineChart(dataPlot).show()
  }

  /**
   * Entry point: builds the data set, configures and trains the network,
   * then reports and plots the fit.
   *
   * @param args unused
   */
  def main(args: Array[String]): Unit = {
    // NOTE(review): presumably enables ND4J's numerical-stability guards — confirm in ND4J docs.
    Nd4j.ENFORCE_NUMERICAL_STABILITY = true

    val numInputs = 1
    val numOutputs = 1
    val numHiddenNodes = 20
    val nSamples = 1000

    // nSamples evenly spaced points on [-10, 10], reshaped to an (nSamples x 1) column.
    // The points are -10 + 20*i/(nSamples-1); 0 is hit only when nSamples is odd,
    // so with an even nSamples the sin(x)/x below never evaluates 0/0.
    val x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1)
    // The Sin op is given x.dup() as its output array, presumably so x itself is
    // not overwritten before it is used as the divisor.
    // (For a linear sanity-check target, use `x.mul(2)` instead.)
    val y = Nd4j.getExecutioner().execAndReturn(new Sin(x, x.dup())).div(x)
    val dataSet = new DataSet(x, y)

    val seed = 0
    val iterations = 5000
    val numEpochs = 1
    val learningRate = 100

    val conf = new NeuralNetConfiguration.Builder()
      .seed(seed)
      .constrainGradientToUnitNorm(false)
      .iterations(iterations)
      .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
      .learningRate(learningRate)
      .list(2)
      .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
        .weightInit(WeightInit.XAVIER)
        .activation("tanh")
        .build())
      .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
        .weightInit(WeightInit.XAVIER)
        .activation("identity")
        .nIn(numHiddenNodes).nOut(numOutputs).build())
      .backprop(true)
      .build()

    val network = new MultiLayerNetwork(conf)
    network.init()
    // Log the score on every iteration.
    network.setListeners(new ScoreIterationListener(1))

    // One fit call per epoch over the full data set.
    (1 to numEpochs).foreach(_ => network.fit(dataSet))

    val yEst = network.output(x)
    // Standard deviation of the residuals y - yEst, reported as the fit quality.
    val error = Nd4j.std(y.sub(yEst))
    println(s"Residual std: $error")
    println(s"Z$yEst")
    println(s"Y$y")
    plotXY(x, yEst)
  }
}
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement