Advertisement
Guest User

Untitled

a guest
Mar 19th, 2019
75
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.38 KB | None | 0 0
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3.  
  4. """
  5. CS224N 2018-19: Homework 5
  6. vocab.py: Vocabulary Generation
  7. Pengcheng Yin <pcyin@cs.cmu.edu>
  8. Sahil Chopra <schopra8@stanford.edu>
  9.  
  10. Usage:
  11. vocab.py --train-src=<file> --train-tgt=<file> [options] VOCAB_FILE
  12.  
  13. Options:
  14. -h --help Show this screen.
  15. --train-src=<file> File of training source sentences
  16. --train-tgt=<file> File of training target sentences
  17. --size=<int> vocab size [default: 50000]
  18. --freq-cutoff=<int> frequency cutoff [default: 2]
  19. """
  20.  
  21. from collections import Counter
  22. from docopt import docopt
  23. from itertools import chain
  24. import json
  25. import torch
  26. from typing import List
  27. from utils import read_corpus, pad_sents, pad_sents_char
  28.  
  class VocabEntry(object):
      """ Vocabulary Entry, i.e. structure containing either
      src or tgt language terms.

      Holds a word-level vocabulary (`word2id` / `id2word`) and a fixed
      character-level vocabulary (`char2id` / `id2char`).
      """
      def __init__(self, word2id=None):
          """ Init VocabEntry Instance.
          @param word2id (dict): dictionary mapping words 2 indices. When it is
              None or empty, a fresh mapping holding only the four special
              tokens is created. A supplied mapping must contain '<unk>'.
          """
          if word2id:
              self.word2id = word2id
          else:
              self.word2id = dict()
              self.word2id['<pad>'] = 0   # Pad Token
              self.word2id['<s>'] = 1     # Start Token
              self.word2id['</s>'] = 2    # End Token
              self.word2id['<unk>'] = 3   # Unknown Token
          self.unk_id = self.word2id['<unk>']
          self.id2word = {v: k for k, v in self.word2id.items()}

          ## Additions to the A4 code:
          self.char_list = list("""ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]""")

          self.char2id = dict()  # Converts characters to integers
          self.char2id['<pad>'] = 0
          self.char2id['{'] = 1
          self.char2id['}'] = 2
          self.char2id['<unk>'] = 3
          # Remaining characters get consecutive ids after the special entries.
          for c in self.char_list:
              self.char2id[c] = len(self.char2id)
          self.char_unk = self.char2id['<unk>']
          self.start_of_word = self.char2id["{"]
          self.end_of_word = self.char2id["}"]
          assert self.start_of_word+1 == self.end_of_word

          self.id2char = {v: k for k, v in self.char2id.items()}  # Converts integers to characters
          ## End additions to the A4 code

      def __getitem__(self, word):
          """ Retrieve word's index. Return the index for the unk
          token if the word is out of vocabulary.
          @param word (str): word to look up.
          @returns index (int): index of word
          """
          return self.word2id.get(word, self.unk_id)

      def __contains__(self, word):
          """ Check if word is captured by VocabEntry.
          @param word (str): word to look up
          @returns contains (bool): whether word is contained
          """
          return word in self.word2id

      def __setitem__(self, key, value):
          """ Raise error, if one tries to edit the VocabEntry.
          @raises ValueError: always; use `add` to insert new words.
          """
          raise ValueError('vocabulary is readonly')

      def __len__(self):
          """ Compute number of words in VocabEntry.
          @returns len (int): number of words in VocabEntry
          """
          return len(self.word2id)

      def __repr__(self):
          """ Representation of VocabEntry to be used
          when printing the object.
          """
          return 'Vocabulary[size=%d]' % len(self)

      def id2word(self, wid):
          """ Return mapping of index to word.
          @param wid (int): word index
          @returns word (str): word corresponding to index

          NOTE: `__init__` stores a dict under the same name on every instance,
          so `entry.id2word` resolves to that dict and this method is not
          reachable through an instance. Use `entry.id2word[wid]` or
          `indices2words` instead.
          """
          return self.id2word[wid]

      def add(self, word):
          """ Add word to VocabEntry, if it is previously unseen.
          @param word (str): word to add to VocabEntry
          @return index (int): index that the word has been assigned
          """
          if word not in self:
              # New words take the next free index, i.e. the current size.
              wid = self.word2id[word] = len(self)
              self.id2word[wid] = word
              return wid
          else:
              return self[word]

      def words2charindices(self, sents):
          """ Convert list of sentences of words into list of list of list of character indices.

          Each word is wrapped with the `start_of_word` and `end_of_word`
          indices. Characters missing from `char2id` map to `char_unk`
          instead of raising KeyError.

          @param sents (list[list[str]]): sentence(s) in words
          @return word_ids (list[list[list[int]]]): sentence(s) in indices
          """
          return [
              [
                  [self.start_of_word]
                  + [self.char2id.get(c, self.char_unk) for c in w]
                  + [self.end_of_word]
                  for w in s
              ]
              for s in sents
          ]

      def words2indices(self, sents):
          """ Convert list of sentences of words into list of list of indices.
          Out-of-vocabulary words map to the unk index.
          @param sents (list[list[str]]): sentence(s) in words
          @return word_ids (list[list[int]]): sentence(s) in indices
          """
          return [[self[w] for w in s] for s in sents]

      def indices2words(self, word_ids):
          """ Convert list of indices into words.
          @param word_ids (list[int]): list of word ids
          @return sents (list[str]): list of words
          @raises KeyError: if an id is not present in the vocabulary
          """
          return [self.id2word[w_id] for w_id in word_ids]

      def to_input_tensor_char(self, sents: List[List[str]]) -> torch.Tensor:
          """ Convert list of sentences (words) into a character-index tensor with
          necessary padding for shorter sentences and words.

          @param sents (List[List[str]]): list of sentences (words)

          @returns sents_var: tensor of (max_sentence_length, batch_size, max_word_length)
          """
          char_ids = self.words2charindices(sents)
          # Padding is delegated to pad_sents_char; it is assumed to return a
          # rectangular batch-major nested list -- TODO confirm against utils.
          sents_t = pad_sents_char(char_ids, self.char2id['<pad>'])
          sents_var = torch.LongTensor(sents_t)
          # Swap the first two axes so the sentence-position axis comes first.
          return sents_var.permute(1, 0, 2).contiguous()

      def to_input_tensor(self, sents: List[List[str]]) -> torch.Tensor:
          """ Convert list of sentences (words) into tensor with necessary padding for
          shorter sentences.

          @param sents (List[List[str]]): list of sentences (words)

          @returns sents_var: tensor of (max_sentence_length, batch_size)
          """
          word_ids = self.words2indices(sents)
          sents_t = pad_sents(word_ids, self['<pad>'])
          sents_var = torch.LongTensor(sents_t)
          # Transpose from batch-major to sentence-position-major.
          return torch.t(sents_var)

      @staticmethod
      def from_corpus(corpus, size, freq_cutoff=2):
          """ Given a corpus construct a Vocab Entry.
          @param corpus (list[list[str]]): corpus of tokenized sentences, as
              produced by the read_corpus function
          @param size (int): maximum # of corpus words added to the vocabulary
              (the four special tokens are not counted)
          @param freq_cutoff (int): if word occurs n < freq_cutoff times, drop the word
          @returns vocab_entry (VocabEntry): VocabEntry instance produced from provided corpus
          """
          vocab_entry = VocabEntry()
          # from_iterable avoids unpacking the whole corpus into call arguments.
          word_freq = Counter(chain.from_iterable(corpus))
          valid_words = [w for w, v in word_freq.items() if v >= freq_cutoff]
          print('number of word types: {}, number of word types w/ frequency >= {}: {}'
                .format(len(word_freq), freq_cutoff, len(valid_words)))
          # Most frequent first; the stable sort keeps first-seen order among ties.
          top_k_words = sorted(valid_words, key=lambda w: word_freq[w], reverse=True)[:size]
          for word in top_k_words:
              vocab_entry.add(word)
          return vocab_entry
  211.  
  212.  
  class Vocab(object):
      """ Vocab encapsulating src and target languages.
      """
      def __init__(self, src_vocab: 'VocabEntry', tgt_vocab: 'VocabEntry'):
          """ Init Vocab.
          @param src_vocab (VocabEntry): VocabEntry for source language
          @param tgt_vocab (VocabEntry): VocabEntry for target language
          """
          self.src = src_vocab
          self.tgt = tgt_vocab

      @staticmethod
      def build(src_sents, tgt_sents, vocab_size, freq_cutoff) -> 'Vocab':
          """ Build Vocabulary.
          @param src_sents (list[str]): Source sentences provided by read_corpus() function
          @param tgt_sents (list[str]): Target sentences provided by read_corpus() function.
              Currently unused: the target vocabulary is built from `src_sents`.
          @param vocab_size (int): Size of vocabulary for both source and target languages
          @param freq_cutoff (int): if word occurs n < freq_cutoff times, drop the word.
          @returns vocab (Vocab): Vocab holding the two generated VocabEntry objects
          """
          # assert len(src_sents) == len(tgt_sents)

          print('initialize source vocabulary ..')
          src = VocabEntry.from_corpus(src_sents, vocab_size, freq_cutoff)

          # NOTE: the target side deliberately reuses the source sentences, as
          # the message below states; `tgt_sents` is accepted but ignored.
          print('initialize target vocabulary .. (I made it just the source repeated)')
          tgt = VocabEntry.from_corpus(src_sents, vocab_size, freq_cutoff)

          return Vocab(src, tgt)

      def save(self, file_path):
          """ Save Vocab to file as JSON dump.
          Only the two word2id mappings are written.
          @param file_path (str): file path to vocab file
          """
          payload = dict(src_word2id=self.src.word2id, tgt_word2id=self.tgt.word2id)
          # Context manager guarantees the file is flushed and closed.
          with open(file_path, 'w') as f:
              json.dump(payload, f, indent=2)

      @staticmethod
      def load(file_path):
          """ Load vocabulary from JSON dump.
          @param file_path (str): file path to vocab file
          @returns Vocab object loaded from JSON dump
          """
          with open(file_path, 'r') as f:
              entry = json.load(f)
          src_word2id = entry['src_word2id']
          tgt_word2id = entry['tgt_word2id']

          return Vocab(VocabEntry(src_word2id), VocabEntry(tgt_word2id))

      def __repr__(self):
          """ Representation of Vocab to be used
          when printing the object.
          """
          return 'Vocab(source %d words, target %d words)' % (len(self.src), len(self.tgt))
  265.  
  266.  
  267.  
  if __name__ == '__main__':
      # docopt derives the accepted options from the module docstring.
      cli = docopt(__doc__)
      src_path, tgt_path = cli['--train-src'], cli['--train-tgt']
      out_path = cli['VOCAB_FILE']

      print('read in source sentences: %s' % src_path)
      print('read in target sentences: %s' % tgt_path)

      # Load both sides of the training corpus.
      src_sents = read_corpus(src_path, source='src')
      tgt_sents = read_corpus(tgt_path, source='tgt')

      # Numeric options arrive as strings and are converted at the call site.
      vocab = Vocab.build(
          src_sents,
          tgt_sents,
          int(cli['--size']),
          int(cli['--freq-cutoff']),
      )
      src_count, tgt_count = len(vocab.src), len(vocab.tgt)
      print('generated vocabulary, source %d words, target %d words' % (src_count, tgt_count))

      # Persist the vocabulary to the requested output file.
      vocab.save(out_path)
      print('vocabulary saved to %s' % out_path)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement