Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
- import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
- import org.deeplearning4j.nn.conf.inputs.InputType;
- import org.deeplearning4j.nn.conf.layers.LSTM;
- import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
- import org.deeplearning4j.nn.graph.ComputationGraph;
- import org.nd4j.linalg.activations.Activation;
- import org.nd4j.linalg.api.ndarray.INDArray;
- import org.nd4j.linalg.factory.Nd4j;
- import org.nd4j.linalg.indexing.INDArrayIndex;
- import org.nd4j.linalg.indexing.NDArrayIndex;
- import org.nd4j.linalg.learning.config.Adam;
- import org.nd4j.linalg.lossfunctions.LossFunctions;
- import java.util.Map;
- public class NaNProblem {
- public static void main(String... args){
- int batchSize = 5;
- int inputSize = 300;
- int hiddenSize = 100;
- int questionLen = 10;
- INDArray features = Nd4j.rand(new int[] {batchSize, inputSize, questionLen});
- INDArray labels = Nd4j.zeros(1, questionLen);
- labels.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.point(2)}).assign(1);
- labels = labels.reshape(new int[] {1, 1, questionLen}).repeat(0, (long) batchSize);
- INDArray labelsMask = Nd4j.zeros(1, questionLen);
- //NOTE: If labels' mask is all ones, we don't get NaNs
- labelsMask.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(0,4)}).assign(1);
- labelsMask = labelsMask.reshape(new int[] {1, 1, questionLen}).repeat(0, (long) batchSize);
- ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
- .updater(new Adam())
- .graphBuilder()
- .addInputs("features")
- .setInputTypes(InputType.recurrent(inputSize))
- .addLayer("lstm", new LSTM.Builder().nIn(inputSize).nOut(hiddenSize).build(), "features")
- //Cross Entropy loss function causes NaN
- .addLayer("output", new RnnOutputLayer.Builder().nIn(hiddenSize).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY).build(), "lstm")
- //MSE loss function doesn't cause NaN
- // .addLayer("output", new RnnOutputLayer.Builder().nIn(hiddenSize).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "lstm")
- .setOutputs("output")
- .allowDisconnected(true)
- .build();
- ComputationGraph net = new ComputationGraph(configuration);
- net.init();
- net.fit(new INDArray[] {features}, new INDArray[] {labels}, null, new INDArray[] {labelsMask});
- Map<String, INDArray> results = net.feedForward(new INDArray[]{features}, false);
- System.out.println(results.get("output"));
- //Network output is NaN
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement