Advertisement
Guest User

Untitled

a guest
Sep 29th, 2016
64
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.07 KB | None | 0 0
  package com.jaloo.data;

  import org.nd4j.linalg.api.ndarray.INDArray;
  import org.nd4j.linalg.api.rng.DefaultRandom;
  import org.nd4j.linalg.dataset.DataSet;
  import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
  import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  import org.nd4j.linalg.factory.Nd4j;
  import org.nd4j.linalg.indexing.INDArrayIndex;
  import org.nd4j.linalg.indexing.NDArrayIndex;

  import java.util.List;
  import java.util.NoSuchElementException;
  import java.util.Objects;
  import java.util.function.Consumer;
  14.  
  15. /**
  16. * Created by paul on 9/29/16.
  17. */
  18. public class StandaloneDataIterator implements DataSetIterator {
  19. private int vectorSize = 4;
  20. private int lengthWindow = 128;
  21. private int lengthWindowOut = 35;
  22. private int cursor = 0;
  23. private int numSamples = 100;
  24. private int batchSize = 50;
  25. private int numLabels = 200;
  26.  
  27. @Override
  28. public boolean hasNext() {
  29. return cursor < numSamples;
  30. }
  31.  
  32. @Override
  33. public DataSet next() {
  34. return next(batchSize);
  35. }
  36.  
  37. @Override
  38. public void remove() {
  39. throw new UnsupportedOperationException("Not implemented");
  40. }
  41.  
  42. @Override
  43. public void forEachRemaining(Consumer<? super DataSet> action) {
  44. throw new UnsupportedOperationException("Not implemented");
  45. }
  46.  
  47. @Override
  48. public DataSet next(int batchSize) {
  49. INDArray features = Nd4j.create(batchSize, vectorSize * lengthWindow, 'c');
  50. INDArray labels = Nd4j.create(batchSize, totalOutcomes(), lengthWindowOut);
  51. INDArray labelsMask = Nd4j.zeros(batchSize, lengthWindowOut);
  52.  
  53. DefaultRandom defaultRandom = new DefaultRandom(12345);
  54. int shape[] = {1, vectorSize};
  55.  
  56. for( int i=0; i<batchSize; i++ ) {
  57. for(int j=0;j<lengthWindow;j++) {
  58. INDArray row = defaultRandom.nextDouble(shape);
  59. for(int k=0;k<vectorSize;k++){
  60. features.put(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.point(j*vectorSize+k)},row.getDouble(0, k));
  61. }
  62. }
  63. int lastIdx = lengthWindowOut - 1;
  64. labels.putScalar(new int[]{i,0,lastIdx},1.0);
  65. labelsMask.putScalar(new int[]{i,lastIdx},1.0);
  66. }
  67. cursor += batchSize;
  68. return new DataSet(features,labels, null, labelsMask);
  69. }
  70.  
  71.  
  72. @Override
  73. public int totalExamples() {
  74. return numSamples;
  75. }
  76.  
  77. @Override
  78. public int inputColumns() {
  79. return vectorSize*lengthWindow;
  80. }
  81.  
  82. @Override
  83. public int totalOutcomes() {
  84. return numLabels;
  85. }
  86.  
  87. @Override
  88. public boolean resetSupported() {
  89. return true;
  90. }
  91.  
  92. @Override
  93. public boolean asyncSupported() {
  94. return false;
  95. }
  96.  
  97. @Override
  98. public void reset() {
  99. cursor = 0;
  100. }
  101.  
  102. @Override
  103. public int batch() {
  104. return batchSize;
  105. }
  106.  
  107. @Override
  108. public int cursor() {
  109. return cursor;
  110. }
  111.  
  112. @Override
  113. public int numExamples() {
  114. return totalExamples();
  115. }
  116.  
  117. @Override
  118. public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
  119. throw new UnsupportedOperationException("Not implemented");
  120. }
  121.  
  122. @Override
  123. public DataSetPreProcessor getPreProcessor() {
  124. throw new UnsupportedOperationException("Not implemented");
  125. }
  126.  
  127. @Override
  128. public List<String> getLabels() {
  129. throw new UnsupportedOperationException("Not implemented");
  130. }
  131. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement