Guest User

Untitled

a guest
Feb 25th, 2018
90
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.92 KB | None | 0 0
  1. #!/usr/bin/env python3
  2. from tensor2tensor.data_generators import text_encoder
  3.  
  4. import tensorflow as tf
  5. import sys
  6.  
  7. flags = tf.flags
  8. FLAGS = flags.FLAGS
  9.  
  10. flags.DEFINE_string("vocab", None, "Path to the subword vocabulary")
  11. flags.DEFINE_string("src", None, "Path to the source-language text")
  12. flags.DEFINE_string("trg", None, "Path to the target-language text")
  13. # TODO print the actual subwords, use vocab._subtoken_id_to_subtoken_string() instead of _subtoken_ids_to_tokens()
  14. flags.DEFINE_bool("print", False, "Print a character for each subword?")
  15.  
  16. def eprint(*args, **kwargs):
  17. print(*args, file=sys.stderr, **kwargs)
  18.  
  19. def words_subwords(vocab, string):
  20. #subwords = vocab._subtoken_ids_to_tokens([x]) for x in vocab.encode(string)]
  21. n_words = len(string.split())
  22. n_subwords = len(vocab.encode(string))
  23. return n_words, n_subwords
  24.  
  25. s_words, t_words, m_words = 0, 0, 0
  26. s_subws, t_subws, m_subws = 0, 0, 0
  27. sents = 0
  28.  
  29. def print_stats():
  30. global s_words, t_words, m_words, s_subws, t_subws, m_subws, sents
  31. eprint("\ntotal: sents=%d words=%d subwords=%s subwords/words %.4f" % (sents, m_words, m_subws, m_subws/m_words))
  32. eprint("source: words=%d subwords=%d" % (s_words, s_subws))
  33. eprint("target: words=%d subwords=%d" % (t_words, t_subws))
  34.  
  35. def main(_):
  36. global s_words, t_words, m_words, s_subws, t_subws, m_subws, sents
  37. vocab = text_encoder.SubwordTextEncoder(FLAGS.vocab)
  38. with open(FLAGS.src, encoding="utf-8") as src, open(FLAGS.trg, encoding="utf-8") as trg:
  39. for s, t in zip(src, trg):
  40. sents += 1
  41. s = s.strip()
  42. t = t.strip()
  43. s_w, s_s = words_subwords(vocab, s)
  44. t_w, t_s = words_subwords(vocab, t)
  45. s_words += s_w
  46. t_words += t_w
  47. m_words += max(s_w, t_w)
  48. s_subws += s_s
  49. t_subws += t_s
  50. m_subws += max(s_s, t_s)
  51. if sents % 100000 == 0:
  52. print_stats()
  53. if FLAGS.print:
  54. print("a" * max(s_s, t_s))
  55. print_stats()
  56.  
  57.  
  58. if __name__ == "__main__":
  59. tf.app.run()
Add Comment
Please, Sign In to add comment