Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
/**
 * Builds a small two-layer DL4J network wrapped for Spark training, saves it with
 * `ModelSerializer` through the Hadoop `FileSystem`, reads it back, and prints whether the
 * restored parameters and configuration match the originals.
 *
 * @param args optional; `args(0)` overrides the model path (defaults to "D:/nn_model/model.zip")
 */
def main(args: Array[String]): Unit = {
  val spark = SparkSession.builder().master("local").appName("disaggregation_test_task_000").getOrCreate()
  try {
    val fileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)

    // 4 inputs -> dense layer of 3 tanh units -> 3-unit softmax output (NLL loss).
    val conf = new NeuralNetConfiguration.Builder()
      .weightInit(WeightInit.XAVIER)
      .updater(Updater.NESTEROVS)
      .learningRate(0.1)
      .list
      .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH).build)
      .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .activation(Activation.SOFTMAX).nIn(3).nOut(3).build)
      .backprop(true)
      .pretrain(false)
      .build

    // NOTE(review): export directory is "/tmp/dn4j" — possibly meant "dl4j"; confirm.
    val tm = new ParameterAveragingTrainingMaster.Builder(8)
      .exportDirectory("/tmp/dn4j")
      .averagingFrequency(5)
      .workerPrefetchNumBatches(2)
      .batchSizePerWorker(8)
      .build()
    val sparkNet = new SparkDl4jMultiLayer(spark.sparkContext, conf, tm)
    val net = sparkNet.getNetwork
    net.init()

    val modelPath = new Path(args.headOption.getOrElse("D:/nn_model/model.zip"))

    // Write the zip bytes straight to the file. An ObjectOutputStream wrapper would add Java
    // serialization framing and leave a file that is not a valid zip. The stream must be
    // closed (which flushes the buffer) before the same path is opened for reading.
    val out = new BufferedOutputStream(fileSystem.create(modelPath))
    try ModelSerializer.writeModel(net, out, false)
    finally out.close()

    val in = new BufferedInputStream(fileSystem.open(modelPath))
    val restored =
      try ModelSerializer.restoreMultiLayerNetwork(in)
      finally in.close()

    println("Saved and loaded parameters are equal: " + (net.params == restored.params))
    println("Saved and loaded configurations are equal: " +
      (net.getLayerWiseConfigurations == restored.getLayerWiseConfigurations))
  } finally {
    spark.stop()
  }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement