#!/usr/bin/env python3
import numpy as np
import tensorflow as tf

import morpho_dataset

class Network:
    def __init__(self, threads, seed=42):
        # Create an empty graph and a session
        graph = tf.Graph()
        graph.seed = seed
        self.session = tf.Session(graph=graph, config=tf.ConfigProto(inter_op_parallelism_threads=threads,
                                                                     intra_op_parallelism_threads=threads))

    def construct(self, args, num_words, num_chars, num_tags):
        with self.session.graph.as_default():
            if args.recodex:
                tf.get_variable_scope().set_initializer(tf.glorot_uniform_initializer(seed=42))

            # Inputs
            self.sentence_lens = tf.placeholder(tf.int32, [None], name="sentence_lens")
            self.word_ids = tf.placeholder(tf.int32, [None, None], name="word_ids")
            self.charseqs = tf.placeholder(tf.int32, [None, None], name="charseqs")
            self.charseq_lens = tf.placeholder(tf.int32, [None], name="charseq_lens")
            self.charseq_ids = tf.placeholder(tf.int32, [None, None], name="charseq_ids")
            self.tags = tf.placeholder(tf.int32, [None, None], name="tags")

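            # Shapes, as implied by the feeding code below: `sentence_lens` is [batch];
            # `word_ids`, `charseq_ids` and `tags` are [batch, max_sentence_len];
            # `charseqs` holds the character sequences of the unique words in the batch
            # as [unique_words, max_word_len], and `charseq_ids` maps every word position
            # to its row in `charseqs`.
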
            # TODO(we): Choose RNN cell class according to args.rnn_cell (LSTM and GRU
            # should be supported, using tf.nn.rnn_cell.{BasicLSTM,GRU}Cell).
            if args.rnn_cell == "LSTM":
                rnn_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(args.rnn_cell_dim)
                rnn_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(args.rnn_cell_dim)
            elif args.rnn_cell == "GRU":
                rnn_cell_fw = tf.nn.rnn_cell.GRUCell(args.rnn_cell_dim)
                rnn_cell_bw = tf.nn.rnn_cell.GRUCell(args.rnn_cell_dim)
            else:
                raise ValueError("Unknown rnn_cell: {}".format(args.rnn_cell))

            # TODO(we): Create word embeddings for num_words of dimensionality args.we_dim
            # using `tf.get_variable`.
            word_e = tf.get_variable("embed", shape=[num_words, args.we_dim])

            # TODO(we): Embed self.word_ids according to the word embeddings, by utilizing
            # `tf.nn.embedding_lookup`.
            embd_words = tf.nn.embedding_lookup(word_e, self.word_ids)

            # Convolutional word embeddings (CNNE)

            # TODO: Generate character embeddings for num_chars of dimensionality args.cle_dim.
            char_e = tf.get_variable("char", shape=[num_chars, args.cle_dim])

            # TODO: Embed self.charseqs (list of unique words in the batch) using the character embeddings.
            embd_chars = tf.nn.embedding_lookup(char_e, self.charseqs)

            # TODO: For kernel sizes of {2..args.cnne_max}, do the following:
            # - use `tf.layers.conv1d` on input embedded characters, with given kernel size
            #   and `args.cnne_filters`; use `VALID` padding, stride 1 and no activation.
            # - perform channel-wise max-pooling over the whole word, generating output
            #   of size `args.cnne_filters` for every word.
            features = []
            for kernel_size in range(2, args.cnne_max + 1):
                before_max_pool = tf.layers.conv1d(embd_chars, args.cnne_filters, kernel_size,
                                                   strides=1, padding='valid', activation=None)
                # Channel-wise max-pooling over the whole word, producing
                # `args.cnne_filters` features for every word.
                max_pooled = tf.reduce_max(before_max_pool, axis=1)
                features.append(max_pooled)

            # TODO: Concatenate the computed features (in the order of kernel sizes 2..args.cnne_max).
            # Consequently, each word from `self.charseqs` is represented using convolutional embedding
            # (CNNE) of size `(args.cnne_max-1)*args.cnne_filters`.
            cnne_embed = tf.concat(features, axis=1)

            # TODO: Generate CNNEs of all words in the batch by indexing the just computed embeddings
            # by self.charseq_ids (using tf.nn.embedding_lookup).
            all_cnnes = tf.nn.embedding_lookup(cnne_embed, self.charseq_ids)

            # TODO: Concatenate the word embeddings (computed above) and the CNNE (in this order).
            concatenated = tf.concat([embd_words, all_cnnes], axis=2)

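            # Every sentence position is now represented by its word embedding followed by its
            # CNNE, i.e. a vector of size args.we_dim + (args.cnne_max - 1) * args.cnne_filters;
            # the bidirectional RNN below contextualizes these representations over the sentence.
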
            # TODO(we): Using tf.nn.bidirectional_dynamic_rnn, process the embedded inputs.
            # Use given rnn_cell (different for fwd and bwd direction) and self.sentence_lens.
            outputs, _ = tf.nn.bidirectional_dynamic_rnn(rnn_cell_fw,
                                                         rnn_cell_bw,
                                                         concatenated,
                                                         dtype=tf.float32,
                                                         sequence_length=self.sentence_lens)
            out_fw, out_bw = outputs

            # TODO(we): Concatenate the outputs for fwd and bwd directions (in the third dimension).
            concat_output = tf.concat([out_fw, out_bw], axis=2)
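            # `concat_output` has shape [batch, max_sentence_len, 2 * args.rnn_cell_dim].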

            # TODO(we): Add a dense layer (without activation) into num_tags classes and
            # store result in `output_layer`.
            output_layer = tf.layers.dense(concat_output, num_tags, activation=None)

            # TODO(we): Generate `self.predictions`.
            self.predictions = tf.argmax(output_layer, axis=2)

            # TODO(we): Generate `weights` as a 1./0. mask of valid/invalid words (using `tf.sequence_mask`).
            weights = tf.sequence_mask(self.sentence_lens, dtype=tf.float32)

            # Training

            # TODO(we): Define `loss` using `tf.losses.sparse_softmax_cross_entropy`, but additionally
            # use `weights` parameter to mask-out invalid words.
            loss = tf.losses.sparse_softmax_cross_entropy(self.tags, output_layer, weights=weights, scope="loss")

            global_step = tf.train.create_global_step()
            self.training = tf.train.AdamOptimizer().minimize(loss, global_step=global_step, name="training")

            # Summaries
            self.current_accuracy, self.update_accuracy = tf.metrics.accuracy(self.tags, self.predictions, weights=weights)
            self.current_loss, self.update_loss = tf.metrics.mean(loss, weights=tf.reduce_sum(weights))
            self.reset_metrics = tf.variables_initializer(tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))

            summary_writer = tf.contrib.summary.create_file_writer(args.logdir, flush_millis=10 * 1000)
            self.summaries = {}
            with summary_writer.as_default(), tf.contrib.summary.record_summaries_every_n_global_steps(10):
                self.summaries["train"] = [tf.contrib.summary.scalar("train/loss", self.update_loss),
                                           tf.contrib.summary.scalar("train/accuracy", self.update_accuracy)]
            with summary_writer.as_default(), tf.contrib.summary.always_record_summaries():
                for dataset in ["dev", "test"]:
                    self.summaries[dataset] = [tf.contrib.summary.scalar(dataset + "/loss", self.current_loss),
                                               tf.contrib.summary.scalar(dataset + "/accuracy", self.current_accuracy)]

            # Initialize variables
            self.session.run(tf.global_variables_initializer())
            with summary_writer.as_default():
                tf.contrib.summary.initialize(session=self.session, graph=self.session.graph)

    def train_epoch(self, train, batch_size):
        while not train.epoch_finished():
            sentence_lens, word_ids, charseq_ids, charseqs, charseq_lens = train.next_batch(batch_size, including_charseqs=True)
            self.session.run(self.reset_metrics)
            self.session.run([self.training, self.summaries["train"]],
                             {self.sentence_lens: sentence_lens,
                              self.charseqs: charseqs[train.FORMS], self.charseq_lens: charseq_lens[train.FORMS],
                              self.word_ids: word_ids[train.FORMS], self.charseq_ids: charseq_ids[train.FORMS],
                              self.tags: word_ids[train.TAGS]})

    def evaluate(self, dataset_name, dataset, batch_size):
        self.session.run(self.reset_metrics)
        while not dataset.epoch_finished():
            sentence_lens, word_ids, charseq_ids, charseqs, charseq_lens = dataset.next_batch(batch_size, including_charseqs=True)
            self.session.run([self.update_accuracy, self.update_loss],
                             {self.sentence_lens: sentence_lens,
                              self.charseqs: charseqs[dataset.FORMS], self.charseq_lens: charseq_lens[dataset.FORMS],
                              self.word_ids: word_ids[dataset.FORMS], self.charseq_ids: charseq_ids[dataset.FORMS],
                              self.tags: word_ids[dataset.TAGS]})
        return self.session.run([self.current_accuracy, self.summaries[dataset_name]])[0]


if __name__ == "__main__":
    import argparse
    import datetime
    import os
    import re

    # Fix random seed
    np.random.seed(42)

    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=10, type=int, help="Batch size.")
    parser.add_argument("--cle_dim", default=32, type=int, help="Character-level embedding dimension.")
    parser.add_argument("--cnne_filters", default=16, type=int, help="CNN embedding filters per length.")
    parser.add_argument("--cnne_max", default=4, type=int, help="Maximum CNN filter length.")
    parser.add_argument("--epochs", default=10, type=int, help="Number of epochs.")
    parser.add_argument("--recodex", default=False, action="store_true", help="ReCodEx mode.")
    parser.add_argument("--rnn_cell", default="LSTM", type=str, help="RNN cell type.")
    parser.add_argument("--rnn_cell_dim", default=64, type=int, help="RNN cell dimension.")
    parser.add_argument("--threads", default=1, type=int, help="Maximum number of threads to use.")
    parser.add_argument("--we_dim", default=64, type=int, help="Word embedding dimension.")
    args = parser.parse_args()

    # Create logdir name
    args.logdir = "logs/{}-{}-{}".format(
        os.path.basename(__file__),
        datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S"),
        ",".join(("{}={}".format(re.sub("(.)[^_]*_?", r"\1", key), value) for key, value in sorted(vars(args).items())))
    )
    if not os.path.exists("logs"): os.mkdir("logs") # TF 1.6 will do this by itself

    # Load the data
    train = morpho_dataset.MorphoDataset("czech-cac-train.txt", max_sentences=5000)
    dev = morpho_dataset.MorphoDataset("czech-cac-dev.txt", train=train, shuffle_batches=False)

    # Construct the network
    network = Network(threads=args.threads)
    network.construct(args, len(train.factors[train.FORMS].words), len(train.factors[train.FORMS].alphabet),
                      len(train.factors[train.TAGS].words))

    # Train
    for i in range(args.epochs):
        network.train_epoch(train, args.batch_size)

        accuracy = network.evaluate("dev", dev, args.batch_size)
        print("{:.2f}".format(100 * accuracy))
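
# Example invocation (a minimal sketch; the script name below is assumed, and
# `morpho_dataset.py` plus the czech-cac data files must sit in the working directory):
#   python3 tagger_cnne.py --rnn_cell=GRU --cnne_max=4 --cnne_filters=16 --epochs=10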