from fastai.learner import *

import torchtext
from torchtext import vocab, data
from torchtext.datasets import language_modeling

from fastai.rnn_reg import *
from fastai.rnn_train import *
from fastai.nlp import *
from fastai.lm_rnn import *

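# dill is used in place of the standard pickle module because it can serialize
# more kinds of objects (e.g. the tokenizer function attached to the Field below)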
import dill as pickle


PATH = "data/"
TRN_PATH = 'train/'
VAL_PATH = 'valid/'
TST_PATH = 'test/'
TRN = f'{PATH}{TRN_PATH}'
VAL = f'{PATH}{VAL_PATH}'
TST = f'{PATH}{TST_PATH}'


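# torchtext Field: lower-case all text and tokenize with spaCy (spacy_tok comes from fastai.nlp)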
TEXT = data.Field(lower=True, tokenize=spacy_tok)

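# batch size and bptt (backprop through time): roughly how many tokens each training sequence holds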
bs = 64
bptt = 70

FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=TST_PATH)

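# build the LanguageModelData object; tokens seen fewer than min_freq times are mapped to <unk>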
md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)

# pickle the now-populated Field so the vocab can be reloaded later
pickle.dump(TEXT, open(f'{PATH}models/TEXT.pkl', 'wb'))

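# quick sanity check: (batches per epoch, vocab size, number of documents, tokens in the training text)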
len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)

em_sz = 200  # size of each embedding vector
nh = 500     # number of hidden activations per layer
nl = 3       # number of layers

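# Adam with reduced momentum (beta1=0.7 rather than the 0.9 default), which tends to work better
# for training these RNN language models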
opt_fn = partial(optim.Adam, betas=(0.7, 0.99))

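# AWD-LSTM-style dropout: dropouti on the embedded input, dropoute on the embedding matrix itself,
# wdrop on the LSTM hidden-to-hidden weights, dropouth between LSTM layers, dropout on the output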
learner = md.get_model(opt_fn, em_sz, nh, nl,
               dropouti=0.05, dropout=0.05, wdrop=0.1, dropoute=0.02, dropouth=0.05)
# activation regularization (alpha) and temporal activation regularization (beta), as in AWD-LSTM
learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
# clip gradients to stabilize training
learner.clip = 0.3

# it's overfitting a bit by the later epochs :/
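# SGDR schedule: 4 cycles with cycle_len=1 and cycle_mult=2 gives 1 + 2 + 4 + 8 = 15 epochs total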
learner.fit(3e-3, 4, wds=1e-6, cycle_len=1, cycle_mult=2)

# save model and pickle vocab
learner.save_encoder('austen_adam_enc')
pickle.dump(TEXT, open(f'{PATH}models/TEXT.pkl', 'wb'))


# now time to try to make it write stuff
learner.load_encoder('austen_adam_enc')

from numpy.random import choice

def generate_token(txt):
    m = learner.model
    s = [spacy_tok(txt)]
    t = TEXT.numericalize(s)
    # Set batch size to 1
    m[0].bs = 1
    # Turn off dropout
    m.eval()
    # Reset hidden state
    m.reset()
    # Get predictions from model
    res, *_ = m(t)

    # Scores for the token following the last input token
    nexts = torch.topk(res[-1], 10)

    # The decoder emits raw logits, which can be negative, so exponentiate
    # (softmax-style) before normalizing into a valid distribution over the top 10
    top_scores = to_np(nexts[0])
    unnormalized_probs = np.exp(top_scores - top_scores.max())
    probs = unnormalized_probs / unnormalized_probs.sum()

    # Draw two distinct candidates so there is a fallback if the first is <unk> (index 0)
    draw = choice(to_np(nexts[1]), 2, p=probs, replace=False)

    d = draw[0] if draw[0] != 0 else draw[1]
    w = TEXT.vocab.itos[d]
    # reset batch size
    m[0].bs = bs

    return w

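# generate up to 100 tokens, feeding the growing string back through the model each time,
# and stop at the end of the sentence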
s = "elizabeth"
for i in range(100):
    tok = generate_token(s)
    s += ' ' + tok
    if tok == '.':
        break

s
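
# A small convenience wrapper (not in the original paste): it just bundles the sampling
# loop above into a reusable function. `seed` and `max_tokens` are names introduced
# here for illustration.
def generate_text(seed, max_tokens=100):
    out = seed
    for _ in range(max_tokens):
        tok = generate_token(out)
        out += ' ' + tok
        if tok == '.':
            break
    return out

# e.g. generate_text('elizabeth was') should produce one Austen-flavored sentence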