from itertools import chain
from pathlib import Path

import tensorflow as tf
from deeppavlov import build_model
from deeppavlov.core.commands.train import train_evaluate_model_from_config
from deeppavlov.core.common.registry import register
from deeppavlov.core.data.data_learning_iterator import DataLearningIterator
from deeppavlov.core.data.dataset_reader import DatasetReader
from deeppavlov.core.data.simple_vocab import SimpleVocabulary
from deeppavlov.core.data.utils import download_decompress
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.tf_model import TFModel
from deeppavlov.models.tokenizers.lazy_tokenizer import LazyTokenizer

# Uncomment on first run to download and unpack the PersonaChat dataset:
# download_decompress('http://files.deeppavlov.ai/datasets/personachat_v2.tar.gz', './personachat')

@register('personachat_dataset_reader')
class PersonaChatDatasetReader(DatasetReader):
    """
    PersonaChat dataset from
    Zhang S. et al. Personalizing Dialogue Agents: I have a dog, do you have pets too?
    https://arxiv.org/abs/1801.07243
    This dataset is also used in ConvAI2: http://convai.io/
    This class reads the dataset into the following format:
    [{
        'persona': [list of persona sentences],
        'x': input utterance,
        'y': output utterance,
        'dialog_history': [list of previous utterances],
        'candidates': [list of candidate utterances],
        'y_idx': index of y in the candidates list
      },
       ...
    ]
    """

    def read(self, dir_path: str, mode='none_original'):
        dir_path = Path(dir_path)
        dataset = {}
        for dt in ['train', 'valid', 'test']:
            dataset[dt] = self._parse_data(dir_path / '{}_{}.txt'.format(dt, mode))

        return dataset

    @staticmethod
    def _parse_data(filename):
        examples = []
        print(filename)
        curr_persona = []
        curr_dialog_history = []
        persona_done = False
        with filename.open('r') as fin:
            for line in fin:
                line = ' '.join(line.strip().split(' ')[1:])
                your_persona_pref = 'your persona: '
                if line[:len(your_persona_pref)] == your_persona_pref and persona_done:
                    curr_persona = [line[len(your_persona_pref):]]
                    curr_dialog_history = []
                    persona_done = False
                elif line[:len(your_persona_pref)] == your_persona_pref:
                    curr_persona.append(line[len(your_persona_pref):])
                else:
                    persona_done = True
                    x, y, _, candidates = line.split('\t')
                    candidates = candidates.split('|')
                    example = {
                        'persona': curr_persona,
                        'x': x,
                        'y': y,
                        'dialog_history': curr_dialog_history[:],
                        'candidates': candidates,
                        'y_idx': candidates.index(y)
                    }
                    curr_dialog_history.extend([x, y])
                    examples.append(example)

        return examples


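# The raw PersonaChat files parsed by _parse_data look roughly like this
# (a sketch inferred from the parser above, not an exact excerpt): every line
# starts with a line number, persona lines carry the "your persona: " prefix,
# and dialog lines are tab-separated as
#   x <TAB> y <TAB> <ignored> <TAB> candidate_1|candidate_2|...|y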
data = PersonaChatDatasetReader().read('./personachat')

for k in data:
    print(k, len(data[k]))

print(data['train'][0])


@register('personachat_iterator')
class PersonaChatIterator(DataLearningIterator):
    def split(self, *args, **kwargs):
        for dt in ['train', 'valid', 'test']:
            setattr(self, dt, self._to_tuple(getattr(self, dt)))

    @staticmethod
    def _to_tuple(data):
        """
        Returns:
            list of (x, y)
        """
        return list(map(lambda x: (x['x'], x['y']), data))


iterator = PersonaChatIterator(data)

batch = next(iterator.gen_batches(5, 'train'))
for x, y in zip(*batch):
    print('x:', x)
    print('y:', y)
    print('----------')


tokenizer = LazyTokenizer()
tokenizer(['Hello my friend'])
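# LazyTokenizer splits each utterance into tokens; the call above should
# return something like [['Hello', 'my', 'friend']].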

@register('dialog_vocab')
class DialogVocab(SimpleVocabulary):
    def fit(self, *args):
        tokens = chain(*args)
        super().fit(tokens)

    def __call__(self, batch, **kwargs):
        indices_batch = []
        for utt in batch:
            tokens = [self[token] for token in utt]
            indices_batch.append(tokens)
        return indices_batch


vocab = DialogVocab(
    save_path='./vocab.dict',
    load_path='./vocab.dict',
    min_freq=2,
    special_tokens=('<PAD>', '<BOS>', '<EOS>', '<UNK>',),
    unk_token='<UNK>'
)

vocab.fit(tokenizer(iterator.get_instances(data_type='train')[0]),
          tokenizer(iterator.get_instances(data_type='train')[1]))
vocab.save()

vocab.freqs.most_common(10)

len(vocab)
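# In a notebook, the two bare expressions above display the most frequent tokens
# and the vocabulary size; the vocab_size of 11595 used later on presumably
# corresponds to len(vocab) for this training split.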

vocab([['<BOS>', 'hello', 'my', 'friend', 'there_is_no_such_word_in_dataset', 'and_this', '<EOS>', '<PAD>']])
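# Tokens missing from the vocabulary are mapped to the <UNK> id; with the
# special-token order above, <PAD>=0, <BOS>=1, <EOS>=2 and <UNK>=3.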


@register('sentence_padder')
class SentencePadder(Component):
    def __init__(self, length_limit, pad_token_id=0, start_token_id=1, end_token_id=2, *args, **kwargs):
        self.length_limit = length_limit
        self.pad_token_id = pad_token_id
        self.start_token_id = start_token_id
        self.end_token_id = end_token_id

    def __call__(self, batch):
        for i in range(len(batch)):
            batch[i] = batch[i][:self.length_limit]
            batch[i] = [self.start_token_id] + batch[i] + [self.end_token_id]
            batch[i] += [self.pad_token_id] * (self.length_limit + 2 - len(batch[i]))
        return batch


padder = SentencePadder(length_limit=6)

vocab(padder(vocab([['hello', 'my', 'friend', 'there_is_no_such_word_in_dataset', 'and_this']])))
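# Each sentence is truncated to length_limit tokens, wrapped in <BOS>/<EOS> and
# padded to length_limit + 2, so the round trip above should give something like
# [['<BOS>', 'hello', 'my', 'friend', '<UNK>', '<UNK>', '<EOS>', '<PAD>']]
# (assuming 'hello', 'my' and 'friend' made it into the vocabulary).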


def encoder(inputs, inputs_len, embedding_matrix, cell_size, keep_prob=1.0):
    # inputs: tf.int32 tensor with shape bs x seq_len with token ids
    # inputs_len: tf.int32 tensor with shape bs
    # embedding_matrix: tf.float32 tensor with shape vocab_size x vocab_dim
    # cell_size: hidden size of recurrent cell
    # keep_prob: dropout keep probability
    with tf.variable_scope('encoder'):
        # first of all we embed every token in the input sequence (tf.nn.embedding_lookup plus dropout)
        x_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding_matrix, inputs), keep_prob=keep_prob)

        # define recurrent cell (LSTM or GRU)
        encoder_cell = tf.nn.rnn_cell.GRUCell(
            num_units=cell_size,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name='encoder_cell')

        # use tf.nn.dynamic_rnn to encode the input sequence, using the actual length of each input
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(cell=encoder_cell, inputs=x_emb, sequence_length=inputs_len,
                                                           dtype=tf.float32)
    return encoder_outputs, encoder_state


tf.reset_default_graph()
vocab_size = 100
hidden_dim = 100
inputs = tf.cast(tf.random_uniform(shape=[32, 10]) * vocab_size, tf.int32)  # bs x seq_len
mask = tf.cast(tf.random_uniform(shape=[32, 10]) * 2, tf.int32)  # bs x seq_len
inputs_len = tf.reduce_sum(mask, axis=1)
embedding_matrix = tf.random_uniform(shape=[vocab_size, hidden_dim])

encoder(inputs, inputs_len, embedding_matrix, hidden_dim)
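# Smoke test: with a GRU cell of size 100, encoder_outputs has shape
# [32, 10, 100] (bs x seq_len x cell_size) and encoder_state has shape [32, 100].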


def decoder(encoder_outputs, encoder_state, embedding_matrix, mask,
            cell_size, max_length, y_ph,
            start_token_id=1, keep_prob=1.0,
            teacher_forcing_rate_ph=None,
            use_attention=False, is_train=True):
    # decoder
    # encoder_outputs: tf.float32 tensor with shape bs x seq_len x encoder_cell_size
    # encoder_state: tf.float32 tensor with shape bs x encoder_cell_size
    # embedding_matrix: tf.float32 tensor with shape vocab_size x vocab_dim
    # mask: tf.int32 tensor with shape bs x seq_len with zeros for masked sequence elements
    # cell_size: hidden size of recurrent cell
    # max_length: max length of output sequence
    # start_token_id: id of <BOS> token in vocabulary
    # keep_prob: dropout keep probability
    # teacher_forcing_rate_ph: rate of using teacher forcing on each decoding step
    # use_attention: use attention over encoder outputs or use only encoder_state
    # is_train: is it training or inference? at inference time we can't use teacher forcing
    with tf.variable_scope('decoder'):
        # define decoder recurrent cell
        decoder_cell = tf.nn.rnn_cell.GRUCell(
            num_units=cell_size,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name='decoder_cell')

        # the output token from the previous step is initialized with the start token
        output_token = tf.ones(shape=(tf.shape(encoder_outputs)[0],), dtype=tf.int32) * start_token_id
        # initialize the decoder state with encoder_state
        decoder_state = encoder_state

        pred_tokens = []
        logits = []

        # use a for loop to sequentially call the recurrent cell
        for i in range(max_length):
            """
            TEACHER FORCING
            # here you can try to implement teacher forcing for your model
            # details about teacher forcing are explained further in the tutorial

            # pseudo code:
            NOTE THAT THE FOLLOWING CONDITIONS SHOULD BE EVALUATED AT GRAPH RUNTIME
            use tf.cond and tf.logical operations instead of a python if

            if i > 0 and is_train and random_value < teacher_forcing_rate_ph:
                input_token = y_ph[:, i-1]
            else:
                input_token = output_token

            input_token_emb = tf.nn.embedding_lookup(embedding_matrix, input_token)

            """
            if i > 0:
                input_token_emb = tf.cond(
                    tf.logical_and(
                        is_train,
                        tf.random_uniform(shape=(), maxval=1) <= teacher_forcing_rate_ph
                    ),
                    lambda: tf.nn.embedding_lookup(embedding_matrix, y_ph[:, i - 1]),  # teacher forcing
                    lambda: tf.nn.embedding_lookup(embedding_matrix, output_token)
                )
            else:
                input_token_emb = tf.nn.embedding_lookup(embedding_matrix, output_token)

            """
            ATTENTION MECHANISM
            # here you can add attention to your model
            # you can find details about attention further in the tutorial
            """
            if use_attention:
                # compute attention and concat the attention vector to input_token_emb
                att = dot_attention(encoder_outputs, decoder_state, mask, scope='att')
                input_token_emb = tf.concat([input_token_emb, att], axis=-1)

            input_token_emb = tf.nn.dropout(input_token_emb, keep_prob=keep_prob)
            # call recurrent cell
            decoder_outputs, decoder_state = decoder_cell(input_token_emb, decoder_state)
            decoder_outputs = tf.nn.dropout(decoder_outputs, keep_prob=keep_prob)
            # project decoder output to embeddings dimension
            embeddings_dim = embedding_matrix.get_shape()[1]
            output_proj = tf.layers.dense(decoder_outputs, embeddings_dim, activation=tf.nn.tanh,
                                          kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                          name='proj', reuse=tf.AUTO_REUSE)
            # compute logits by comparing the projection with every token embedding
            output_logits = tf.matmul(output_proj, embedding_matrix, transpose_b=True)

            logits.append(output_logits)
            output_probs = tf.nn.softmax(output_logits)
            output_token = tf.argmax(output_probs, axis=-1)
            pred_tokens.append(output_token)

        y_pred_tokens = tf.transpose(tf.stack(pred_tokens, axis=0), [1, 0])
        y_logits = tf.transpose(tf.stack(logits, axis=0), [1, 0, 2])
    return y_pred_tokens, y_logits


tf.reset_default_graph()
vocab_size = 100
hidden_dim = 100
inputs = tf.cast(tf.random_uniform(shape=[32, 10]) * vocab_size, tf.int32)  # bs x seq_len
mask = tf.cast(tf.random_uniform(shape=[32, 10]) * 2, tf.int32)  # bs x seq_len
inputs_len = tf.reduce_sum(mask, axis=1)
embedding_matrix = tf.random_uniform(shape=[vocab_size, hidden_dim])

teacher_forcing_rate = tf.random_uniform(shape=())
y = tf.cast(tf.random_uniform(shape=[32, 10]) * vocab_size, tf.int32)

encoder_outputs, encoder_state = encoder(inputs, inputs_len, embedding_matrix, hidden_dim)
decoder(encoder_outputs, encoder_state, embedding_matrix, mask, hidden_dim, max_length=10,
        y_ph=y, teacher_forcing_rate_ph=teacher_forcing_rate)
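# Smoke test: y_pred_tokens has shape [32, 10] (bs x max_length) and y_logits
# has shape [32, 10, 100] (bs x max_length x vocab_size).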



@register('seq2seq')
class Seq2Seq(TFModel):
    def __init__(self, **kwargs):
        # hyperparameters

        # dimension of word embeddings
        self.embeddings_dim = kwargs.get('embeddings_dim', 100)
        # size of recurrent cell in encoder and decoder
        self.cell_size = kwargs.get('cell_size', 200)
        # dropout keep_probability
        self.keep_prob = kwargs.get('keep_prob', 0.8)
        # learning rate
        self.learning_rate = kwargs.get('learning_rate', 3e-04)
        # max length of output sequence
        self.max_length = kwargs.get('max_length', 20)
        self.grad_clip = kwargs.get('grad_clip', 5.0)
        self.start_token_id = kwargs.get('start_token_id', 1)
        self.vocab_size = kwargs.get('vocab_size', 11595)
        self.teacher_forcing_rate = kwargs.get('teacher_forcing_rate', 0.0)
        self.use_attention = kwargs.get('use_attention', False)

        # create a tensorflow session to run the computational graph in
        self.sess_config = tf.ConfigProto(allow_soft_placement=True)
        self.sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=self.sess_config)

        self.init_graph()

        # define train op
        self.train_op = self.get_train_op(self.loss, self.lr_ph,
                                          optimizer=tf.train.AdamOptimizer,
                                          clip_norm=self.grad_clip)
        # initialize graph variables
        self.sess.run(tf.global_variables_initializer())

        super().__init__(**kwargs)
        # load a saved model if there is one
        if self.load_path is not None:
            self.load()

    def init_graph(self):
        # create placeholders
        self.init_placeholders()

        # pad id is 0, so non-zero token ids mark the real sequence positions
        self.x_mask = tf.cast(tf.not_equal(self.x_ph, 0), tf.int32)
        self.y_mask = tf.cast(tf.not_equal(self.y_ph, 0), tf.int32)

        self.x_len = tf.reduce_sum(self.x_mask, axis=1)

        # create embeddings matrix for tokens
        self.embeddings = tf.Variable(
            tf.random_uniform((self.vocab_size, self.embeddings_dim), -0.1, 0.1, name='embeddings'), dtype=tf.float32)

        # encoder
        encoder_outputs, encoder_state = encoder(self.x_ph, self.x_len, self.embeddings, self.cell_size,
                                                 self.keep_prob_ph)

        # decoder
        self.y_pred_tokens, y_logits = decoder(encoder_outputs, encoder_state, self.embeddings, self.x_mask,
                                               self.cell_size, self.max_length,
                                               self.y_ph, self.start_token_id, self.keep_prob_ph,
                                               self.teacher_forcing_rate_ph, self.use_attention, self.is_train_ph)

        # loss: masked token-level cross-entropy
        self.y_ohe = tf.one_hot(self.y_ph, depth=self.vocab_size)
        self.y_mask = tf.cast(self.y_mask, tf.float32)
        self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.y_ohe, logits=y_logits) * self.y_mask
        self.loss = tf.reduce_sum(self.loss) / tf.reduce_sum(self.y_mask)

    def init_placeholders(self):
        # placeholders for inputs
        self.x_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='x_ph')
        # y_ph stays in the computational graph even at inference time (the teacher-forcing branch references it),
        # so we give it a dummy default value; the dummy is never actually used at inference
        self.y_ph = tf.placeholder_with_default(tf.zeros_like(self.x_ph), shape=(None, None), name='y_ph')

        # placeholders for model parameters
        self.lr_ph = tf.placeholder(dtype=tf.float32, shape=[], name='lr_ph')
        self.keep_prob_ph = tf.placeholder_with_default(1.0, shape=[], name='keep_prob_ph')
        self.is_train_ph = tf.placeholder_with_default(False, shape=[], name='is_train_ph')
        self.teacher_forcing_rate_ph = tf.placeholder_with_default(0.0, shape=[], name='teacher_forcing_rate_ph')

    def _build_feed_dict(self, x, y=None):
        feed_dict = {
            self.x_ph: x,
        }
        if y is not None:
            feed_dict.update({
                self.y_ph: y,
                self.lr_ph: self.learning_rate,
                self.keep_prob_ph: self.keep_prob,
                self.is_train_ph: True,
                self.teacher_forcing_rate_ph: self.teacher_forcing_rate,
            })
        return feed_dict

    def train_on_batch(self, x, y):
        feed_dict = self._build_feed_dict(x, y)
        loss, _ = self.sess.run([self.loss, self.train_op], feed_dict=feed_dict)
        return loss

    def __call__(self, x):
        feed_dict = self._build_feed_dict(x)
        y_pred = self.sess.run(self.y_pred_tokens, feed_dict=feed_dict)
        return y_pred


s2s = Seq2Seq(
    save_path='./save/seq2seq_model',
    load_path='./save/seq2seq_model'
)

vocab(s2s(padder(vocab([['hello', 'my', 'friend', 'there_is_no_such_word_in_dataset', 'and_this']]))))
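# Unless a checkpoint already exists under load_path, the weights above are
# freshly initialized, so the decoded reply at this point is essentially random.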


def softmax_mask(values, mask):
    # add a large negative number to the masked positions so that softmax ignores them
    INF = 1e30
    return -INF * (1 - tf.cast(mask, tf.float32)) + values


def dot_attention(memory, state, mask, scope="dot_attention"):
    # memory: bs x seq_len x hidden_dim
    # state: bs x hidden_dim
    # mask: bs x seq_len
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # dot product between each item in memory and state
        logits = tf.matmul(memory, tf.expand_dims(state, axis=1), transpose_b=True)
        logits = tf.squeeze(logits, [2])

        # apply mask to logits
        logits = softmax_mask(logits, mask)

        # apply softmax to logits
        att_weights = tf.expand_dims(tf.nn.softmax(logits), axis=2)

        # compute weighted sum of items in memory
        att = tf.reduce_sum(att_weights * memory, axis=1)
        return att


tf.reset_default_graph()
memory = tf.random_normal(shape=[32, 10, 100])  # bs x seq_len x hidden_dim
state = tf.random_normal(shape=[32, 100])  # bs x hidden_dim
mask = tf.cast(tf.random_normal(shape=[32, 10]), tf.int32)  # bs x seq_len
dot_attention(memory, state, mask)
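# Smoke test: the attention output has shape [32, 100] (bs x hidden_dim), one
# weighted sum of memory items per batch element.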


@register('postprocessing')
class SentencePostprocessor(Component):
    def __init__(self, pad_token='<PAD>', start_token='<BOS>', end_token='<EOS>', *args, **kwargs):
        self.pad_token = pad_token
        self.start_token = start_token
        self.end_token = end_token

    def __call__(self, batch):
        for i in range(len(batch)):
            batch[i] = ' '.join(self._postproc(batch[i]))
        return batch

    def _postproc(self, utt):
        if self.end_token in utt:
            utt = utt[:utt.index(self.end_token)]
        return utt


postprocess = SentencePostprocessor()

postprocess(vocab(s2s(padder(vocab([['hello', 'my', 'friend', 'there_is_no_such_word_in_dataset', 'and_this']])))))

config = {
    "dataset_reader": {
        "class_name": "personachat_dataset_reader",
        "data_path": "./personachat"
    },
    "dataset_iterator": {
        "class_name": "personachat_iterator",
        "seed": 1337,
        "shuffle": True
    },
    "chainer": {
        "in": ["x"],
        "in_y": ["y"],
        "pipe": [
            {
                "class_name": "lazy_tokenizer",
                "id": "tokenizer",
                "in": ["x"],
                "out": ["x_tokens"]
            },
            {
                "ref": "tokenizer",
                "in": ["y"],
                "out": ["y_tokens"]
            },
            {
                "class_name": "dialog_vocab",
                "id": "vocab",
                "save_path": "./vocab.dict",
                "load_path": "./vocab.dict",
                "min_freq": 2,
                "special_tokens": ["<PAD>", "<BOS>", "<EOS>", "<UNK>"],
                "unk_token": "<UNK>",
                "fit_on": ["x_tokens", "y_tokens"],
                "in": ["x_tokens"],
                "out": ["x_tokens_ids"]
            },
            {
                "ref": "vocab",
                "in": ["y_tokens"],
                "out": ["y_tokens_ids"]
            },
            {
                "class_name": "sentence_padder",
                "id": "padder",
                "length_limit": 20,
                "in": ["x_tokens_ids"],
                "out": ["x_tokens_ids"]
            },
            {
                "ref": "padder",
                "in": ["y_tokens_ids"],
                "out": ["y_tokens_ids"]
            },
            {
                "class_name": "seq2seq",
                "id": "s2s",
                "max_length": "#padder.length_limit+2",
                "cell_size": 250,
                "embeddings_dim": 50,
                "vocab_size": 11595,
                "keep_prob": 0.8,
                "learning_rate": 3e-04,
                "teacher_forcing_rate": 0.0,
                "use_attention": False,
                "save_path": "./save/seq2seq_model",
                "load_path": "./save/seq2seq_model",
                "in": ["x_tokens_ids"],
                "in_y": ["y_tokens_ids"],
                "out": ["y_predicted_tokens_ids"]
            },
            {
                "ref": "vocab",
                "in": ["y_predicted_tokens_ids"],
                "out": ["y_predicted_tokens"]
            },
            {
                "class_name": "postprocessing",
                "in": ["y_predicted_tokens"],
                "out": ["y_predicted_tokens"]
            }
        ],
        "out": ["y_predicted_tokens"]
    },
    "train": {
        "log_every_n_batches": 100,
        "val_every_n_epochs": 0,
        "batch_size": 64,
        "validation_patience": 0,
        "epochs": 20,
        "metrics": ["bleu"]
    }
}


# build the untrained pipeline from the config and run it on a couple of inputs
model = build_model(config)
model(['Hi, how are you?', 'Any ideas my dear friend?'])
# train the pipeline end to end, then rebuild it from the saved checkpoint
train_evaluate_model_from_config(config=config)
model = build_model(config)
model(['hi, how are you?', 'any ideas my dear friend?', 'okay, i agree with you', 'good bye!'])