from fastai.learner import *
import torchtext
from torchtext import vocab, data
from torchtext.datasets import language_modeling
from fastai.rnn_reg import *
from fastai.rnn_train import *
from fastai.nlp import *
from fastai.lm_rnn import *
from functools import partial  # used below for the optimiser and regulariser
import dill as pickle
# data layout: data/train/, data/valid/, data/test/
PATH = "data/"
TRN_PATH = 'train/'
VAL_PATH = 'valid/'
TST_PATH = 'test/'
TRN = f'{PATH}{TRN_PATH}'
VAL = f'{PATH}{VAL_PATH}'
TST = f'{PATH}{TST_PATH}'
TEXT = data.Field(lower=True, tokenize=spacy_tok)  # lowercase text, tokenise with spaCy
bs = 64    # batch size
bptt = 70  # backprop-through-time: length of each training sequence
FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=TST_PATH)
md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)
pickle.dump(TEXT, open(f'{PATH}models/TEXT.pkl', 'wb'))
# (batches per epoch, vocab size, number of training texts, tokens in the training text)
len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)
em_sz = 200  # size of each embedding vector
nh = 500     # number of hidden activations per layer
nl = 3       # number of layers
opt_fn = partial(optim.Adam, betas=(0.7, 0.99))
learner = md.get_model(opt_fn, em_sz, nh, nl,
                       dropouti=0.05, dropout=0.05, wdrop=0.1, dropoute=0.02, dropouth=0.05)
learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)  # activation regularisation
learner.clip = 0.3  # gradient clipping
# it's overfitting a bit by the later epochs :/
# 4 SGDR cycles with doubling length (1 + 2 + 4 + 8 = 15 epochs), lr 3e-3, weight decay 1e-6
learner.fit(3e-3, 4, wds=1e-6, cycle_len=1, cycle_mult=2)
# save model and pickle vocab
learner.save_encoder('austen_adam_enc')
pickle.dump(TEXT, open(f'{PATH}models/TEXT.pkl', 'wb'))
# now time to try to make it write stuff
learner.load_encoder('austen_adam_enc')
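# If generating in a fresh session instead, the pickled vocab saved above can be
# restored first (illustrative step, assuming the same PATH layout as when it was dumped):
TEXT = pickle.load(open(f'{PATH}models/TEXT.pkl', 'rb'))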
from numpy.random import choice

def generate_token(txt):
    m = learner.model
    s = [spacy_tok(txt)]
    t = TEXT.numericalize(s)
    # Set batch size to 1
    m[0].bs = 1
    # Turn off dropout
    m.eval()
    # Reset hidden state
    m.reset()
    # Get predictions from model
    res, *_ = m(t)
    # The decoder returns raw (unnormalised) scores; take the ten highest-scoring
    # tokens and softmax them so `choice` gets a valid probability distribution.
    nexts = torch.topk(res[-1], 10)
    scores = to_np(nexts[0])
    exp_scores = np.exp(scores - scores.max())  # subtract max for numerical stability
    probs = exp_scores / exp_scores.sum()
    # Draw two candidates and fall back to the second if the first is <unk> (index 0)
    draw = choice(to_np(nexts[1]), 2, p=probs)
    d = draw[0] if draw[0] != 0 else draw[1]
    w = TEXT.vocab.itos[d]
    # reset batch size
    m[0].bs = bs
    return w
# seed the model and let it write until it produces a full stop (or 100 tokens)
s = "elizabeth"
for i in range(0, 100):
    tok = generate_token(s)
    s += ' ' + tok
    if tok == '.':
        break
s
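# Illustrative wrapper (not in the original paste): the same loop as above, packaged
# so other seed words can be tried easily; "darcy" is just an example prompt.
def write_sentence(seed, max_tokens=100):
    out = seed
    for _ in range(max_tokens):
        tok = generate_token(out)
        out += ' ' + tok
        if tok == '.':
            break
    return out

print(write_sentence("darcy"))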