from fastai.learner import *

import torchtext
from torchtext import vocab, data
from torchtext.datasets import language_modeling

from fastai.rnn_reg import *
from fastai.rnn_train import *
from fastai.nlp import *
from fastai.lm_rnn import *

import dill as pickle

PATH = 'data/'
TRN_PATH = 'train/'
VAL_PATH = 'valid/'
TST_PATH = 'test/'
TRN = f'{PATH}{TRN_PATH}'
VAL = f'{PATH}{VAL_PATH}'
TST = f'{PATH}{TST_PATH}'

# lowercase everything and tokenize with fastai's spaCy-based spacy_tok helper
TEXT = data.Field(lower=True, tokenize=spacy_tok)

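# illustrative check (not in the original paste): spacy_tok splits raw text into
# a list of token strings; lower=True on the Field handles lowercasing afterwards
print(spacy_tok('It is a truth universally acknowledged.'))
# -> roughly ['It', 'is', 'a', 'truth', 'universally', 'acknowledged', '.']
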
bs = 64    # batch size
bptt = 70  # back-propagation-through-time: tokens per sequence chunk

FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=TST_PATH)

# build the vocab (dropping tokens seen fewer than 10 times) and the batched loaders
md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)

# pickle the TEXT field so the exact vocab mapping can be reused later
pickle.dump(TEXT, open(f'{PATH}models/TEXT.pkl', 'wb'))

# batches per epoch, vocab size, number of documents, tokens in the first document
len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)

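# quick sanity check (not in the original paste): md.nt is the vocab size, and
# TEXT.vocab.itos maps each integer id back to its token string
print(md.nt, TEXT.vocab.itos[:5])
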
em_sz = 200  # size of each embedding vector
nh = 500     # number of hidden activations per layer
nl = 3       # number of layers

# Adam with beta1 lowered from the default 0.9, which tends to work better for RNNs
opt_fn = partial(optim.Adam, betas=(0.7, 0.99))

# AWD-LSTM-style model with dropout on inputs, embeddings, hidden layers, and weights
learner = md.get_model(opt_fn, em_sz, nh, nl,
                       dropouti=0.05, dropout=0.05, wdrop=0.1, dropoute=0.02, dropouth=0.05)
learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)  # activation regularization (AR/TAR)
learner.clip = 0.3  # clip gradients to keep training stable

# 4 SGDR cycles of doubling length (1 + 2 + 4 + 8 = 15 epochs);
# it's overfitting a bit by the later epochs :/
learner.fit(3e-3, 4, wds=1e-6, cycle_len=1, cycle_mult=2)

# save the encoder weights and pickle the vocab again
learner.save_encoder('austen_adam_enc')
pickle.dump(TEXT, open(f'{PATH}models/TEXT.pkl', 'wb'))

# now time to try to make it write stuff
learner.load_encoder('austen_adam_enc')

from numpy.random import choice

def generate_token(txt):
    m = learner.model
    s = [spacy_tok(txt)]
    t = TEXT.numericalize(s)
    # Set batch size to 1
    m[0].bs = 1
    # Turn off dropout
    m.eval()
    # Reset hidden state
    m.reset()
    # Get predictions from model
    res, *_ = m(t)

    # the ten highest-scoring candidates for the next token
    nexts = torch.topk(res[-1], 10)

    # the scores are raw logits, so exponentiate before normalizing; dividing
    # the logits by their plain sum can yield negative "probabilities"
    unnormalized_probs = np.exp(to_np(nexts[0]))
    sum_probs = sum(unnormalized_probs)
    probs = [p / sum_probs for p in unnormalized_probs]

    # sample two candidates so there is a fallback if <unk> comes up
    draw = choice(to_np(nexts[1]), 2, p=probs)

    # index 0 is <unk>; take the second draw if the first hits it
    d = draw[0] if draw[0] != 0 else draw[1]
    w = TEXT.vocab.itos[d]
    # reset batch size
    m[0].bs = bs

    return w

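# illustrative single-step usage (not in the original paste):
# generate_token('elizabeth')  # returns one sampled next token, e.g. 'was'
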
# generate up to 100 tokens, stopping at the first full stop
s = 'elizabeth'
for i in range(100):
    tok = generate_token(s)
    s += ' ' + tok
    if tok == '.':
        break

print(s)
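
# a small convenience wrapper (a sketch, not in the original paste): keep sampling
# until n_sentences full stops have been generated, reusing generate_token above
def generate_text(seed, n_sentences=3, max_tokens=100):
    out = seed
    for _ in range(n_sentences):
        for _ in range(max_tokens):
            tok = generate_token(out)
            out += ' ' + tok
            if tok == '.':
                break
    return out

print(generate_text('elizabeth', n_sentences=2))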