Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
package com.jaloo.data;

import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.Consumer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.DefaultRandom;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
- /**
- * Created by paul on 9/29/16.
- */
- public class StandaloneDataIterator implements DataSetIterator {
- private int vectorSize = 4;
- private int lengthWindow = 128;
- private int lengthWindowOut = 35;
- private int cursor = 0;
- private int numSamples = 100;
- private int batchSize = 50;
- private int numLabels = 200;
- @Override
- public boolean hasNext() {
- return cursor < numSamples;
- }
- @Override
- public DataSet next() {
- return next(batchSize);
- }
- @Override
- public void remove() {
- throw new UnsupportedOperationException("Not implemented");
- }
- @Override
- public void forEachRemaining(Consumer<? super DataSet> action) {
- throw new UnsupportedOperationException("Not implemented");
- }
- @Override
- public DataSet next(int batchSize) {
- INDArray features = Nd4j.create(batchSize, vectorSize * lengthWindow, 'c');
- INDArray labels = Nd4j.create(batchSize, totalOutcomes(), lengthWindowOut);
- INDArray labelsMask = Nd4j.zeros(batchSize, lengthWindowOut);
- DefaultRandom defaultRandom = new DefaultRandom(12345);
- int shape[] = {1, vectorSize};
- for( int i=0; i<batchSize; i++ ) {
- for(int j=0;j<lengthWindow;j++) {
- INDArray row = defaultRandom.nextDouble(shape);
- for(int k=0;k<vectorSize;k++){
- features.put(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.point(j*vectorSize+k)},row.getDouble(0, k));
- }
- }
- int lastIdx = lengthWindowOut - 1;
- labels.putScalar(new int[]{i,0,lastIdx},1.0);
- labelsMask.putScalar(new int[]{i,lastIdx},1.0);
- }
- cursor += batchSize;
- return new DataSet(features,labels, null, labelsMask);
- }
- @Override
- public int totalExamples() {
- return numSamples;
- }
- @Override
- public int inputColumns() {
- return vectorSize*lengthWindow;
- }
- @Override
- public int totalOutcomes() {
- return numLabels;
- }
- @Override
- public boolean resetSupported() {
- return true;
- }
- @Override
- public boolean asyncSupported() {
- return false;
- }
- @Override
- public void reset() {
- cursor = 0;
- }
- @Override
- public int batch() {
- return batchSize;
- }
- @Override
- public int cursor() {
- return cursor;
- }
- @Override
- public int numExamples() {
- return totalExamples();
- }
- @Override
- public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
- throw new UnsupportedOperationException("Not implemented");
- }
- @Override
- public DataSetPreProcessor getPreProcessor() {
- throw new UnsupportedOperationException("Not implemented");
- }
- @Override
- public List<String> getLabels() {
- throw new UnsupportedOperationException("Not implemented");
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement