Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
import itertools

import numpy as np

# Dataset / pipeline configuration.
M = 10  # number of samples
d = 2  # feature dimension of each sample
batch_size = 2
n_epochs = 5


def count_occurrences(ids, num_samples):
    """Return an int32 array of length ``num_samples`` counting each id in ``ids``.

    ``ids`` may have any shape (e.g. a ``(batch_size, 1)`` batch of sample ids);
    it is flattened and cast to integer indices. Repeated ids are all counted:
    ``np.add.at`` is unbuffered, whereas ``counts[ids] += 1`` would increment a
    repeated index only once.

    Raises:
        IndexError: if an id is outside ``[-num_samples, num_samples)``.
    """
    counts = np.zeros([num_samples], 'int32')
    np.add.at(counts, np.asarray(ids).astype(np.intp).ravel(), 1)
    return counts


def main():
    """Run a shuffled, repeated, batched tf.data pipeline and report coverage.

    Builds ``M`` random ``d``-dimensional samples, each paired with a unique
    integer id. It then drains the pipeline, printing every batch, and finally
    prints how many times each sample id was fetched.
    """
    # Imported here so the numpy-only helper above can be imported without
    # TensorFlow. Uses the TF 1.x graph/session API (written against TF 1.4.0).
    import tensorflow as tf

    print('tf_version: %s' % tf.__version__)
    np.set_printoptions(linewidth=150, precision=3, suppress=True)

    # samples
    X = tf.constant(np.random.randn(M, d), 'float32')
    # One unique id per sample (a permutation of 0..M-1). Kept integral:
    # float32 cannot represent every integer above 2**24.
    Y = tf.constant(np.expand_dims(np.random.permutation(M), 1), 'int64')
    dset_items = (X, Y)

    # from_tensor_slices slices every item along its first dimension, so the
    # first dimensions must match. Both items are rank 2 here, so each element
    # (and each batch row) is a vector.
    first_dims = [item.shape.as_list()[0] for item in dset_items]
    if any(dim != first_dims[0] for dim in first_dims):
        raise ValueError('first dimensions must match, got %s' % (first_dims,))

    dset = tf.data.Dataset.from_tensor_slices(dset_items)
    # Shuffle before repeat so each epoch is a permutation of all M samples.
    dset = dset.shuffle(M)
    dset = dset.repeat(n_epochs)
    dset = dset.batch(batch_size)
    dset = dset.prefetch(2)
    dset_iterator = dset.make_initializable_iterator()
    next_batch = dset_iterator.get_next()

    occurrence = np.zeros([M], 'int32')
    # The context manager closes the session even if a run call raises.
    with tf.Session() as sess:
        sess.run(dset_iterator.initializer)
        for it in itertools.count(1):
            try:
                xb, yb = sess.run(next_batch)
            except tf.errors.OutOfRangeError:
                # Raised once all n_epochs worth of batches are consumed.
                print('end of dataset')
                break
            occurrence += count_occurrences(yb, M)
            print('%03d, x:%s, y:%s ' % (it, xb.ravel(), yb.ravel()))

    # Every sample is visited once per epoch, so every entry should equal
    # n_epochs (not M).
    print('occurrence array: %s' % (occurrence,))


if __name__ == '__main__':
    main()
Add Comment
Please sign in to add a comment.