Guest User

Untitled

a guest
Nov 20th, 2017
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.33 KB | None | 0 0
  1. import numpy as np
  2. import tensorflow as tf
  3.  
  4. print 'tf_version: ', tf.__version__ # it is 1.4.0 right now
  5. np.set_printoptions(linewidth=150, precision=3, suppress=True)
  6.  
  7. M = 10
  8. d = 2
  9. # samples
  10. X = tf.constant(np.random.randn(M, d), 'float32')
  11. # ids of samples, say each sample have different id
  12. Y = tf.constant(np.expand_dims(np.random.permutation(M), 1), 'float32')
  13.  
  14. dset_items = (X,Y)
  15. # first dimensions must match
  16. # also they should be at least 2 rank
  17. first_dims = [item.shape.as_list()[0] for item in dset_items]
  18. assert np.all(np.equal(first_dims, first_dims[0]))
  19.  
  20. batch_size = 2
  21. n_epochs = 5
  22. dset = tf.data.Dataset.from_tensor_slices(dset_items)
  23. dset = dset.shuffle(M)
  24. dset = dset.repeat(n_epochs)
  25. dset = dset.batch(batch_size)
  26. dset = dset.prefetch(2)
  27. dset_iterator = dset.make_initializable_iterator()
  28. next_batch = dset_iterator.get_next()
  29.  
  30. sess = tf.Session()
  31. sess.run(tf.global_variables_initializer())
  32. sess.run(dset_iterator.initializer)
  33.  
  34. occurrence = np.zeros([M], 'int32')
  35.  
  36. it = 0
  37. while True:
  38. try:
  39. it += 1
  40. xb, yb = sess.run(next_batch)
  41. occurrence[np.int32(yb)] += 1
  42. print '%03d, x:%s, y:%s ' % (it, str(xb.ravel()), str(yb.ravel()))
  43. except tf.errors.OutOfRangeError:
  44. print 'end of dataset'
  45. break
  46.  
  47. sess.close()
  48.  
  49. print 'occurrence array:', occurrence # all entries should be M, indicating each sample is fetched M times
Add Comment
Please, Sign In to add comment