Advertisement
Guest User

Untitled

a guest
Dec 10th, 2016
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.06 KB | None | 0 0
  1. # Benchmark transferring data from TF into Python runtime
  2. # requirement: tf 0.12 (for var.read_value(), ones_initializer())
  3. #
  4. # On Linux default malloc is slow
  5. # sudo apt-get install google-perftools
  6. # export LD_PRELOAD="/usr/lib/libtcmalloc.so.4"
  7. #
  8. # 2014 MacBook:
  9. # 128MB -- 3.56 GB/s
  10. # 1024MB -- 1.96 GB/s
  11. #
  12. # Xeon E5-2630 v3 @ 2.40GHz:
  13. # 128 MB -- 0.43 GB/s (default malloc)
  14. # 128 MB -- 4-6.2 GB/s (tcmalloc)
  15. # 1024 MB -- 4-5.97 GB/s (tcmalloc)
  16.  
  17. import gc
  18. import os
  19. import subprocess
  20. import sys
  21. import tensorflow as tf
  22. import threading
  23. import time
  24.  
  25. flags = tf.flags
  26. flags.DEFINE_integer("iters", 10, "Maximum number of additions")
  27. flags.DEFINE_integer("warmup_iters", 5, "warmup iterations")
  28. flags.DEFINE_integer("data_mb", 128, "size of vector in MBs")
  29. flags.DEFINE_boolean("verbose", False, "extra logging")
  30. flags.DEFINE_boolean("sanity_check", False, "run sanity check on results")
  31. FLAGS = flags.FLAGS
  32.  
  33. def default_config():
  34. optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L0)
  35. config = tf.ConfigProto(
  36. graph_options=tf.GraphOptions(optimizer_options=optimizer_options))
  37. config.log_device_placement = False
  38. config.allow_soft_placement = False
  39. return config
  40.  
  41. def benchmark():
  42. gc.disable()
  43.  
  44. dtype = tf.int32
  45. params_size = 250*1000*FLAGS.data_mb # 1MB is 250k integers
  46. params = tf.get_variable("params", [params_size], dtype,
  47. initializer=tf.ones_initializer())
  48. params_read = params.read_value() # prevent caching
  49. init_op = tf.initialize_all_variables()
  50. sess = tf.Session(config=default_config())
  51. sess.run(init_op)
  52.  
  53. total = 0
  54. for i in range(FLAGS.iters+FLAGS.warmup_iters):
  55. if i == FLAGS.warmup_iters:
  56. start_time = time.time()
  57. # fetch value into Python runtime, and discard value immediately
  58. result = sess.run(params_read)
  59. if FLAGS.sanity_check:
  60. total += result.sum()
  61. print(float(total)/params_size)
  62.  
  63. elapsed_time = time.time() - start_time
  64. rate = float(FLAGS.iters)*FLAGS.data_mb/elapsed_time
  65. print("%.2f MB per second" % (rate))
  66.  
  67. if __name__ == '__main__':
  68. benchmark()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement