Advertisement
Guest User

Untitled

a guest
Jun 27th, 2017
50
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.47 KB | None | 0 0
  /**
   * Round-trip smoke test: save a DL4J network through the Hadoop FileSystem API and read it back.
   *
   * Builds a two-layer MultiLayerNetwork (DenseLayer 4 -> 3 with TANH, OutputLayer 3 -> 3 with
   * SOFTMAX and NEGATIVELOGLIKELIHOOD loss), wraps it in a SparkDl4jMultiLayer with a
   * parameter-averaging training master, and initialises the underlying network. It then writes
   * the network to the model path with ModelSerializer, restores it from the same path, and prints
   * whether the restored parameters and layer-wise configuration equal the originals.
   *
   * @param args optional; `args(0)` overrides the model path (default "D:/nn_model/model.zip").
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("disaggregation_test_task_000").getOrCreate()
    try {
      val fileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)

      val conf = new NeuralNetConfiguration.Builder()
        .weightInit(WeightInit.XAVIER)
        .updater(Updater.NESTEROVS)
        .learningRate(0.1)
        .list
        .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH).build)
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
          .activation(Activation.SOFTMAX).nIn(3).nOut(3).build)
        .backprop(true)
        .pretrain(false)
        .build

      // NOTE(review): the Builder argument (8) is believed to be the number of examples per
      // DataSet object in the training RDD -- confirm against the DL4J version in use.
      val tm = new ParameterAveragingTrainingMaster.Builder(8)
        .exportDirectory("/tmp/dn4j")
        .averagingFrequency(5)
        .workerPrefetchNumBatches(2)
        .batchSizePerWorker(8)
        .build()

      val sparkNet = new SparkDl4jMultiLayer(spark.sparkContext, conf, tm)
      val net = sparkNet.getNetwork
      net.init()

      val modelPath = new Path(args.headOption.getOrElse("D:/nn_model/model.zip"))

      // Write the model straight to the byte stream: an ObjectOutputStream wrapper would add a
      // Java-serialization header and block framing, so the file would no longer be a plain zip.
      // Close the stream before reading it back so the buffered bytes reach the file.
      val output = new BufferedOutputStream(fileSystem.create(modelPath))
      try ModelSerializer.writeModel(net, output, false)
      finally output.close()

      val input = new BufferedInputStream(fileSystem.open(modelPath))
      val restored =
        try ModelSerializer.restoreMultiLayerNetwork(input)
        finally input.close()

      println(s"Saved and loaded parameters are equal: ${net.params == restored.params}")
      println(s"Saved and loaded configurations are equal: ${net.getLayerWiseConfigurations == restored.getLayerWiseConfigurations}")
    } finally {
      // Release the local Spark context even if building, saving or restoring the model fails.
      spark.stop()
    }
  }
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement