from itertools import chain
from pathlib import Path

import tensorflow as tf
from deeppavlov import build_model
from deeppavlov.core.commands.train import train_evaluate_model_from_config
from deeppavlov.core.common.registry import register
from deeppavlov.core.data.data_learning_iterator import DataLearningIterator
from deeppavlov.core.data.dataset_reader import DatasetReader
from deeppavlov.core.data.simple_vocab import SimpleVocabulary
from deeppavlov.core.data.utils import download_decompress
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.tf_model import TFModel
from deeppavlov.models.tokenizers.lazy_tokenizer import LazyTokenizer

# download_decompress('http://files.deeppavlov.ai/datasets/personachat_v2.tar.gz', './personachat')
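# Uncomment the line above on the first run: download_decompress fetches the PersonaChat
# archive and unpacks it into ./personachat, which the dataset reader below expects.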

@register('personachat_dataset_reader')
class PersonaChatDatasetReader(DatasetReader):
    """
    PersonaChat dataset from
    Zhang S. et al. Personalizing Dialogue Agents: I have a dog, do you have pets too?
    https://arxiv.org/abs/1801.07243
    Also, this dataset is used in ConvAI2 http://convai.io/
    This class reads the dataset into the following format:
    [{
        'persona': [list of persona sentences],
        'x': input utterance,
        'y': output utterance,
        'dialog_history': list of previous utterances,
        'candidates': [list of candidate utterances],
        'y_idx': index of y utt in candidates list
      },
       ...
    ]
    """

    def read(self, dir_path: str, mode='none_original'):
        dir_path = Path(dir_path)
        dataset = {}
        for dt in ['train', 'valid', 'test']:
            dataset[dt] = self._parse_data(dir_path / '{}_{}.txt'.format(dt, mode))

        return dataset

    @staticmethod
    def _parse_data(filename):
        examples = []
        print(filename)
        curr_persona = []
        curr_dialog_history = []
        persona_done = False
        with filename.open('r') as fin:
            for line in fin:
                line = ' '.join(line.strip().split(' ')[1:])
                your_persona_pref = 'your persona: '
                if line[:len(your_persona_pref)] == your_persona_pref and persona_done:
                    curr_persona = [line[len(your_persona_pref):]]
                    curr_dialog_history = []
                    persona_done = False
                elif line[:len(your_persona_pref)] == your_persona_pref:
                    curr_persona.append(line[len(your_persona_pref):])
                else:
                    persona_done = True
                    x, y, _, candidates = line.split('\t')
                    candidates = candidates.split('|')
                    example = {
                        'persona': curr_persona,
                        'x': x,
                        'y': y,
                        'dialog_history': curr_dialog_history[:],
                        'candidates': candidates,
                        'y_idx': candidates.index(y)
                    }
                    curr_dialog_history.extend([x, y])
                    examples.append(example)

        return examples


data = PersonaChatDatasetReader().read('./personachat')

for k in data:
    print(k, len(data[k]))

print(data['train'][0])


@register('personachat_iterator')
class PersonaChatIterator(DataLearningIterator):
    def split(self, *args, **kwargs):
        for dt in ['train', 'valid', 'test']:
            setattr(self, dt, self._to_tuple(getattr(self, dt)))

    @staticmethod
    def _to_tuple(data):
        """
        Returns:
            list of (x, y)
        """
        return list(map(lambda x: (x['x'], x['y']), data))


iterator = PersonaChatIterator(data)

batch = next(iterator.gen_batches(5, 'train'))
for x, y in zip(*batch):
    print('x:', x)
    print('y:', y)
    print('----------')


tokenizer = LazyTokenizer()
tokenizer(['Hello my friend'])
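# The tokenizer should return a batch of token lists, e.g. [['Hello', 'my', 'friend']].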

@register('dialog_vocab')
class DialogVocab(SimpleVocabulary):
    def fit(self, *args):
        tokens = chain(*args)
        super().fit(tokens)

    def __call__(self, batch, **kwargs):
        indices_batch = []
        for utt in batch:
            tokens = [self[token] for token in utt]
            indices_batch.append(tokens)
        return indices_batch


vocab = DialogVocab(
    save_path='./vocab.dict',
    load_path='./vocab.dict',
    min_freq=2,
    special_tokens=('<PAD>', '<BOS>', '<EOS>', '<UNK>'),
    unk_token='<UNK>'
)

vocab.fit(tokenizer(iterator.get_instances(data_type='train')[0]),
          tokenizer(iterator.get_instances(data_type='train')[1]))
vocab.save()

vocab.freqs.most_common(10)

len(vocab)
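# Quick sanity checks for the fitted vocabulary: freqs.most_common(10) shows the most
# frequent tokens and len(vocab) is the vocabulary size. The 'vocab_size' passed to the
# seq2seq model below (11595 in this config) should match len(vocab).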

vocab([['<BOS>', 'hello', 'my', 'friend', 'there_is_no_such_word_in_dataset', 'and_this', '<EOS>', '<PAD>']])
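# Each token is mapped to its index; out-of-vocabulary tokens such as
# 'there_is_no_such_word_in_dataset' should map to the index of '<UNK>'
# (3 here, since the special tokens are added to the vocabulary first).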


@register('sentence_padder')
class SentencePadder(Component):
    def __init__(self, length_limit, pad_token_id=0, start_token_id=1, end_token_id=2, *args, **kwargs):
        self.length_limit = length_limit
        self.pad_token_id = pad_token_id
        self.start_token_id = start_token_id
        self.end_token_id = end_token_id

    def __call__(self, batch):
        for i in range(len(batch)):
            batch[i] = batch[i][:self.length_limit]
            batch[i] = [self.start_token_id] + batch[i] + [self.end_token_id]
            batch[i] += [self.pad_token_id] * (self.length_limit + 2 - len(batch[i]))
        return batch


padder = SentencePadder(length_limit=6)

vocab(padder(vocab([['hello', 'my', 'friend', 'there_is_no_such_word_in_dataset', 'and_this']])))
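# A minimal sanity check with raw ids (hypothetical ids, not taken from the fitted vocab):
# with length_limit=3 a sequence is truncated to 3 tokens, wrapped in <BOS>=1 / <EOS>=2
# and padded with <PAD>=0 up to length_limit + 2.
assert SentencePadder(length_limit=3)([[7, 8, 9, 10, 11]]) == [[1, 7, 8, 9, 2]]
assert SentencePadder(length_limit=3)([[7, 8]]) == [[1, 7, 8, 2, 0]]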


def encoder(inputs, inputs_len, embedding_matrix, cell_size, keep_prob=1.0):
    # inputs: tf.int32 tensor with shape bs x seq_len with token ids
    # inputs_len: tf.int32 tensor with shape bs
    # embedding_matrix: tf.float32 tensor with shape vocab_size x vocab_dim
    # cell_size: hidden size of recurrent cell
    # keep_prob: dropout keep probability
    with tf.variable_scope('encoder'):
        # embed every token in the input sequence (tf.nn.embedding_lookup) and apply dropout
        x_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding_matrix, inputs), keep_prob=keep_prob)

        # define the recurrent cell (LSTM or GRU)
        encoder_cell = tf.nn.rnn_cell.GRUCell(
            num_units=cell_size,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name='encoder_cell')

        # use tf.nn.dynamic_rnn to encode the input sequence, using the actual sequence lengths
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(cell=encoder_cell, inputs=x_emb, sequence_length=inputs_len,
                                                           dtype=tf.float32)
    return encoder_outputs, encoder_state


tf.reset_default_graph()
vocab_size = 100
hidden_dim = 100
inputs = tf.cast(tf.random_uniform(shape=[32, 10]) * vocab_size, tf.int32)  # bs x seq_len
mask = tf.cast(tf.random_uniform(shape=[32, 10]) * 2, tf.int32)  # bs x seq_len
inputs_len = tf.reduce_sum(mask, axis=1)
embedding_matrix = tf.random_uniform(shape=[vocab_size, hidden_dim])

encoder(inputs, inputs_len, embedding_matrix, hidden_dim)
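# For the toy tensors above, encoder_outputs has shape bs x seq_len x cell_size
# (32 x 10 x 100) and encoder_state has shape bs x cell_size (32 x 100).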


def decoder(encoder_outputs, encoder_state, embedding_matrix, mask,
            cell_size, max_length, y_ph,
            start_token_id=1, keep_prob=1.0,
            teacher_forcing_rate_ph=None,
            use_attention=False, is_train=True):
    # decoder
    # encoder_outputs: tf.float32 tensor with shape bs x seq_len x encoder_cell_size
    # encoder_state: tf.float32 tensor with shape bs x encoder_cell_size
    # embedding_matrix: tf.float32 tensor with shape vocab_size x vocab_dim
    # mask: tf.int32 tensor with shape bs x seq_len with zeros for masked sequence elements
    # cell_size: hidden size of recurrent cell
    # max_length: max length of output sequence
    # y_ph: tf.int32 tensor with ground-truth output tokens (used for teacher forcing)
    # start_token_id: id of <BOS> token in vocabulary
    # keep_prob: dropout keep probability
    # teacher_forcing_rate_ph: rate of using teacher forcing on each decoding step
    # use_attention: use attention over encoder outputs or use only encoder_state
    # is_train: is it training or inference? at inference time we can't use teacher forcing
    with tf.variable_scope('decoder'):
        # define decoder recurrent cell
        decoder_cell = tf.nn.rnn_cell.GRUCell(
            num_units=cell_size,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name='decoder_cell')

        # the initial value of output_token from the previous step is the start token
        output_token = tf.ones(shape=(tf.shape(encoder_outputs)[0],), dtype=tf.int32) * start_token_id
        # initialize the decoder state with encoder_state
        decoder_state = encoder_state

        pred_tokens = []
        logits = []

        # use a for loop to call the recurrent cell sequentially
        for i in range(max_length):
            """
            TEACHER FORCING
            # here you can try to implement teacher forcing for your model
            # details about teacher forcing are explained further in the tutorial

            # pseudo code:
            NOTE THAT THE FOLLOWING CONDITIONS SHOULD BE EVALUATED AT GRAPH RUNTIME
            use tf.cond and tf.logical operations instead of a python if

            if i > 0 and is_train and random_value < teacher_forcing_rate_ph:
                input_token = y_ph[:, i-1]
            else:
                input_token = output_token

            input_token_emb = tf.nn.embedding_lookup(embedding_matrix, input_token)

            """
            if i > 0:
                input_token_emb = tf.cond(
                    tf.logical_and(
                        is_train,
                        tf.random_uniform(shape=(), maxval=1) <= teacher_forcing_rate_ph
                    ),
                    lambda: tf.nn.embedding_lookup(embedding_matrix, y_ph[:, i - 1]),  # teacher forcing
                    lambda: tf.nn.embedding_lookup(embedding_matrix, output_token)
                )
            else:
                input_token_emb = tf.nn.embedding_lookup(embedding_matrix, output_token)

            """
            ATTENTION MECHANISM
            # here you can add attention to your model
            # you can find details about attention further in the tutorial
            """
            if use_attention:
                # compute attention and concat the attention vector to input_token_emb
                att = dot_attention(encoder_outputs, decoder_state, mask, scope='att')
                input_token_emb = tf.concat([input_token_emb, att], axis=-1)

            input_token_emb = tf.nn.dropout(input_token_emb, keep_prob=keep_prob)
            # call recurrent cell
            decoder_outputs, decoder_state = decoder_cell(input_token_emb, decoder_state)
            decoder_outputs = tf.nn.dropout(decoder_outputs, keep_prob=keep_prob)
            # project decoder output to embeddings dimension
            embeddings_dim = embedding_matrix.get_shape()[1]
            output_proj = tf.layers.dense(decoder_outputs, embeddings_dim, activation=tf.nn.tanh,
                                          kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                          name='proj', reuse=tf.AUTO_REUSE)
            # compute logits over the vocabulary by multiplying with the embedding matrix
            output_logits = tf.matmul(output_proj, embedding_matrix, transpose_b=True)

            logits.append(output_logits)
            output_probs = tf.nn.softmax(output_logits)
            output_token = tf.argmax(output_probs, axis=-1)
            pred_tokens.append(output_token)

        y_pred_tokens = tf.transpose(tf.stack(pred_tokens, axis=0), [1, 0])
        y_logits = tf.transpose(tf.stack(logits, axis=0), [1, 0, 2])
    return y_pred_tokens, y_logits


tf.reset_default_graph()
vocab_size = 100
hidden_dim = 100
inputs = tf.cast(tf.random_uniform(shape=[32, 10]) * vocab_size, tf.int32)  # bs x seq_len
mask = tf.cast(tf.random_uniform(shape=[32, 10]) * 2, tf.int32)  # bs x seq_len
inputs_len = tf.reduce_sum(mask, axis=1)
embedding_matrix = tf.random_uniform(shape=[vocab_size, hidden_dim])

teacher_forcing_rate = tf.random_uniform(shape=())
y = tf.cast(tf.random_uniform(shape=[32, 10]) * vocab_size, tf.int32)

encoder_outputs, encoder_state = encoder(inputs, inputs_len, embedding_matrix, hidden_dim)
decoder(encoder_outputs, encoder_state, embedding_matrix, mask, hidden_dim, max_length=10,
        y_ph=y, teacher_forcing_rate_ph=teacher_forcing_rate)
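# With max_length=10 and the toy 100-token vocabulary above, y_pred_tokens has shape
# bs x max_length (32 x 10) and y_logits has shape bs x max_length x vocab_size (32 x 10 x 100).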


@register('seq2seq')
class Seq2Seq(TFModel):
    def __init__(self, **kwargs):
        # hyperparameters

        # dimension of word embeddings
        self.embeddings_dim = kwargs.get('embeddings_dim', 100)
        # size of recurrent cell in encoder and decoder
        self.cell_size = kwargs.get('cell_size', 200)
        # dropout keep probability
        self.keep_prob = kwargs.get('keep_prob', 0.8)
        # learning rate
        self.learning_rate = kwargs.get('learning_rate', 3e-04)
        # max length of output sequence
        self.max_length = kwargs.get('max_length', 20)
        self.grad_clip = kwargs.get('grad_clip', 5.0)
        self.start_token_id = kwargs.get('start_token_id', 1)
        self.vocab_size = kwargs.get('vocab_size', 11595)
        self.teacher_forcing_rate = kwargs.get('teacher_forcing_rate', 0.0)
        self.use_attention = kwargs.get('use_attention', False)

        # create a tensorflow session to run the computational graph in
        self.sess_config = tf.ConfigProto(allow_soft_placement=True)
        self.sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=self.sess_config)

        self.init_graph()

        # define train op
        self.train_op = self.get_train_op(self.loss, self.lr_ph,
                                          optimizer=tf.train.AdamOptimizer,
                                          clip_norm=self.grad_clip)
        # initialize graph variables
        self.sess.run(tf.global_variables_initializer())

        super().__init__(**kwargs)
        # load a saved model if there is one
        if self.load_path is not None:
            self.load()

    def init_graph(self):
        # create placeholders
        self.init_placeholders()

        # <PAD> has id 0, so nonzero token ids mark the real (non-padding) positions
        self.x_mask = tf.cast(tf.cast(self.x_ph, tf.bool), tf.int32)
        self.y_mask = tf.cast(tf.cast(self.y_ph, tf.bool), tf.int32)

        self.x_len = tf.reduce_sum(self.x_mask, axis=1)

        # create embeddings matrix for tokens
        self.embeddings = tf.Variable(
            tf.random_uniform((self.vocab_size, self.embeddings_dim), -0.1, 0.1, name='embeddings'), dtype=tf.float32)

        # encoder
        encoder_outputs, encoder_state = encoder(self.x_ph, self.x_len, self.embeddings, self.cell_size,
                                                 self.keep_prob_ph)

        # decoder
        self.y_pred_tokens, y_logits = decoder(encoder_outputs, encoder_state, self.embeddings, self.x_mask,
                                               self.cell_size, self.max_length,
                                               self.y_ph, self.start_token_id, self.keep_prob_ph,
                                               self.teacher_forcing_rate_ph, self.use_attention, self.is_train_ph)

        # loss: masked softmax cross-entropy averaged over non-padding tokens
        self.y_ohe = tf.one_hot(self.y_ph, depth=self.vocab_size)
        self.y_mask = tf.cast(self.y_mask, tf.float32)
        self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.y_ohe, logits=y_logits) * self.y_mask
        self.loss = tf.reduce_sum(self.loss) / tf.reduce_sum(self.y_mask)

    def init_placeholders(self):
        # placeholders for inputs
        self.x_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='x_ph')
        # y_ph exists in the computational graph even at inference time (it is needed when
        # teacher forcing is activated), so we give it a dummy default value;
        # this dummy value is never actually used at inference
        self.y_ph = tf.placeholder_with_default(tf.zeros_like(self.x_ph), shape=(None, None), name='y_ph')

        # placeholders for model parameters
        self.lr_ph = tf.placeholder(dtype=tf.float32, shape=[], name='lr_ph')
        self.keep_prob_ph = tf.placeholder_with_default(1.0, shape=[], name='keep_prob_ph')
        self.is_train_ph = tf.placeholder_with_default(False, shape=[], name='is_train_ph')
        self.teacher_forcing_rate_ph = tf.placeholder_with_default(0.0, shape=[], name='teacher_forcing_rate_ph')

    def _build_feed_dict(self, x, y=None):
        feed_dict = {
            self.x_ph: x,
        }
        if y is not None:
            feed_dict.update({
                self.y_ph: y,
                self.lr_ph: self.learning_rate,
                self.keep_prob_ph: self.keep_prob,
                self.is_train_ph: True,
                self.teacher_forcing_rate_ph: self.teacher_forcing_rate,
            })
        return feed_dict

    def train_on_batch(self, x, y):
        feed_dict = self._build_feed_dict(x, y)
        loss, _ = self.sess.run([self.loss, self.train_op], feed_dict=feed_dict)
        return loss

    def __call__(self, x):
        feed_dict = self._build_feed_dict(x)
        y_pred = self.sess.run(self.y_pred_tokens, feed_dict=feed_dict)
        return y_pred


s2s = Seq2Seq(
    save_path='./save/seq2seq_model',
    load_path='./save/seq2seq_model'
)

vocab(s2s(padder(vocab([['hello', 'my', 'friend', 'there_is_no_such_word_in_dataset', 'and_this']]))))
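# At this point the model is still randomly initialized (unless a checkpoint already
# exists at load_path), so the predicted tokens are essentially arbitrary.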


def softmax_mask(values, mask):
    # adds a large negative value to positions where mask == 0
    INF = 1e30
    return -INF * (1 - tf.cast(mask, tf.float32)) + values
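# Worked example: softmax_mask([1., 2.], mask=[1, 0]) gives [1., 2. - 1e30], so after a
# softmax the masked-out position gets (numerically) zero probability.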


def dot_attention(memory, state, mask, scope="dot_attention"):
    # memory: bs x seq_len x hidden_dim
    # state: bs x hidden_dim
    # mask: bs x seq_len
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # dot product between each item in memory and state
        logits = tf.matmul(memory, tf.expand_dims(state, axis=1), transpose_b=True)
        logits = tf.squeeze(logits, [2])

        # apply mask to logits
        logits = softmax_mask(logits, mask)

        # apply softmax to logits
        att_weights = tf.expand_dims(tf.nn.softmax(logits), axis=2)

        # compute weighted sum of items in memory
        att = tf.reduce_sum(att_weights * memory, axis=1)
        return att


tf.reset_default_graph()
memory = tf.random_normal(shape=[32, 10, 100])  # bs x seq_len x hidden_dim
state = tf.random_normal(shape=[32, 100])  # bs x hidden_dim
mask = tf.cast(tf.random_uniform(shape=[32, 10]) * 2, tf.int32)  # bs x seq_len with 0/1 values
dot_attention(memory, state, mask)
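# Sanity check (a sketch): the attention vector collapses the seq_len axis, so its static
# shape should be batch_size x hidden_dim.
print(dot_attention(memory, state, mask).get_shape().as_list())  # [32, 100]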


@register('postprocessing')
class SentencePostprocessor(Component):
    def __init__(self, pad_token='<PAD>', start_token='<BOS>', end_token='<EOS>', *args, **kwargs):
        self.pad_token = pad_token
        self.start_token = start_token
        self.end_token = end_token

    def __call__(self, batch):
        for i in range(len(batch)):
            batch[i] = ' '.join(self._postproc(batch[i]))
        return batch

    def _postproc(self, utt):
        if self.end_token in utt:
            utt = utt[:utt.index(self.end_token)]
        return utt


postprocess = SentencePostprocessor()

postprocess(vocab(s2s(padder(vocab([['hello', 'my', 'friend', 'there_is_no_such_word_in_dataset', 'and_this']])))))

config = {
    "dataset_reader": {
        "class_name": "personachat_dataset_reader",
        "data_path": "./personachat"
    },
    "dataset_iterator": {
        "class_name": "personachat_iterator",
        "seed": 1337,
        "shuffle": True
    },
    "chainer": {
        "in": ["x"],
        "in_y": ["y"],
        "pipe": [
            {
                "class_name": "lazy_tokenizer",
                "id": "tokenizer",
                "in": ["x"],
                "out": ["x_tokens"]
            },
            {
                "ref": "tokenizer",
                "in": ["y"],
                "out": ["y_tokens"]
            },
            {
                "class_name": "dialog_vocab",
                "id": "vocab",
                "save_path": "./vocab.dict",
                "load_path": "./vocab.dict",
                "min_freq": 2,
                "special_tokens": ["<PAD>", "<BOS>", "<EOS>", "<UNK>"],
                "unk_token": "<UNK>",
                "fit_on": ["x_tokens", "y_tokens"],
                "in": ["x_tokens"],
                "out": ["x_tokens_ids"]
            },
            {
                "ref": "vocab",
                "in": ["y_tokens"],
                "out": ["y_tokens_ids"]
            },
            {
                "class_name": "sentence_padder",
                "id": "padder",
                "length_limit": 20,
                "in": ["x_tokens_ids"],
                "out": ["x_tokens_ids"]
            },
            {
                "ref": "padder",
                "in": ["y_tokens_ids"],
                "out": ["y_tokens_ids"]
            },
            {
                "class_name": "seq2seq",
                "id": "s2s",
                "max_length": "#padder.length_limit+2",
                "cell_size": 250,
                "embeddings_dim": 50,
                "vocab_size": 11595,
                "keep_prob": 0.8,
                "learning_rate": 3e-04,
                "teacher_forcing_rate": 0.0,
                "use_attention": False,
                "save_path": "./save/seq2seq_model",
                "load_path": "./save/seq2seq_model",
                "in": ["x_tokens_ids"],
                "in_y": ["y_tokens_ids"],
                "out": ["y_predicted_tokens_ids"]
            },
            {
                "ref": "vocab",
                "in": ["y_predicted_tokens_ids"],
                "out": ["y_predicted_tokens"]
            },
            {
                "class_name": "postprocessing",
                "in": ["y_predicted_tokens"],
                "out": ["y_predicted_tokens"]
            }
        ],
        "out": ["y_predicted_tokens"]
    },
    "train": {
        "log_every_n_batches": 100,
        "val_every_n_epochs": 0,
        "batch_size": 64,
        "validation_patience": 0,
        "epochs": 20,
        "metrics": ["bleu"]
    }
}
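# Note: the pipeline above mirrors the manual steps in this script (tokenize x and y,
# map tokens to ids with the shared vocab, pad, run seq2seq, map predicted ids back to
# tokens, postprocess). The "#padder.length_limit+2" value reuses the padder's
# length_limit; the +2 accounts for the <BOS>/<EOS> tokens added by the padder.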


# build the model from the config and try it before training: unless a checkpoint
# already exists at load_path, the answers come from randomly initialized weights
model = build_model(config)
model(['Hi, how are you?', 'Any ideas my dear friend?'])

# train and evaluate the model according to the "train" section of the config
train_evaluate_model_from_config(config=config)

# rebuild the model so that it loads the trained weights saved during training
model = build_model(config)
model(['hi, how are you?', 'any ideas my dear friend?', 'okay, i agree with you', 'good bye!'])