import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.Map;

public class NaNProblem {

    public static void main(String... args) {
        int batchSize = 5;
        int inputSize = 300;
        int hiddenSize = 100;
        int questionLen = 10;

        // Random input features of shape [batchSize, inputSize, questionLen]
        INDArray features = Nd4j.rand(new int[] {batchSize, inputSize, questionLen});

        // Labels: a single 1 at time step 2, broadcast to every example in the batch
        INDArray labels = Nd4j.zeros(1, questionLen);
        labels.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.point(2)}).assign(1);
        labels = labels.reshape(new int[] {1, 1, questionLen}).repeat(0, (long) batchSize);

        // Label mask: only the first 4 of the 10 time steps are active
        INDArray labelsMask = Nd4j.zeros(1, questionLen);
        // NOTE: If the labels' mask is all ones, we don't get NaNs
        labelsMask.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(0, 4)}).assign(1);
        labelsMask = labelsMask.reshape(new int[] {1, 1, questionLen}).repeat(0, (long) batchSize);

        ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
                .updater(new Adam())
                .graphBuilder()
                .addInputs("features")
                .setInputTypes(InputType.recurrent(inputSize))

                .addLayer("lstm", new LSTM.Builder().nIn(inputSize).nOut(hiddenSize).build(), "features")
                // Cross-entropy loss function causes NaN
                .addLayer("output", new RnnOutputLayer.Builder().nIn(hiddenSize).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY).build(), "lstm")
                // MSE loss function doesn't cause NaN
                //.addLayer("output", new RnnOutputLayer.Builder().nIn(hiddenSize).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "lstm")

                .setOutputs("output")
                .allowDisconnected(true)
                .build();

        ComputationGraph net = new ComputationGraph(configuration);
        net.init();

        // Fit with the per-time-step label mask (no feature mask)
        net.fit(new INDArray[] {features}, new INDArray[] {labels}, null, new INDArray[] {labelsMask});

        Map<String, INDArray> results = net.feedForward(new INDArray[] {features}, false);

        System.out.println(results.get("output"));
        // Network output is NaN
    }
}
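
A possible workaround, offered only as an assumption not confirmed by this paste: TANH can emit negative values, and a log-based cross-entropy loss applied to negative predictions can go NaN, so pairing a SIGMOID activation with the binary cross-entropy (XENT) loss keeps the output in (0, 1). A minimal sketch of that alternative output layer, dropped in place of the "output" layer above:

    // Sketch of an alternative output layer (assumption: NaNs come from log()
    // of negative TANH outputs inside the cross-entropy loss).
    // SIGMOID keeps predictions in (0, 1); XENT is DL4J's binary cross-entropy.
    .addLayer("output", new RnnOutputLayer.Builder()
            .nIn(hiddenSize)
            .nOut(1)
            .activation(Activation.SIGMOID)
            .lossFunction(LossFunctions.LossFunction.XENT)
            .build(), "lstm")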