Advertisement
Guest User

Untitled

a guest
Oct 18th, 2019
92
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.27 KB | None | 0 0
  1. THRESHOLD = 10
  2. MAX_LEN = 60
  3. class TextDataset(data.Dataset):
  4. def __init__(self, examples, split, ixtoword=None, wordtoix=None, THRESHOLD=THRESHOLD):
  5. self.examples = examples
  6. self.split = split
  7. self.THRESHOLD = THRESHOLD
  8. self.vocab_size = 0
  9. self.textual_ids = list()
  10. if self.split == "train":
  11. self.ixtoword = dict()
  12. self.wordtoix = dict()
  13. self.build_dictionary()
  14. else:
  15. self.ixtoword = ixtoword
  16. self.wordtoix = wordtoix
  17.  
  18. ### TO-DO
  19.  
  20. def build_dictionary(self):
  21. ### TO-DO
  22. ### <end> should be at idx 0
  23. ### <unk> should be at idx 1
  24. self.ixtoword[0] = "<end>"
  25. self.ixtoword[1] = "<unk>"
  26. self.wordtoix["<end>"] = 0
  27. self.wordtoix["<unk>"] = 1
  28. cur_id = 2
  29. word_counts = dict()
  30. for sentence in self.examples:
  31. for word in sentence.text:
  32. word = word.lower()
  33. if word in word_counts:
  34. word_counts[word] += 1
  35. else:
  36. word_counts[word] = 1
  37. for word, count in word_counts.items():
  38. word = word.lower()
  39. if count >= THRESHOLD:
  40. self.ixtoword[cur_id] = word
  41. self.wordtoix[word] = cur_id
  42. cur_id += 1
  43. else:
  44. continue
  45. for sentence in self.examples:
  46. sen = list()
  47. for word in sentence.text:
  48. word = word.lower
  49. if word in self.wordtoix:
  50. sen.append(self.wordtoix[word])
  51. else:
  52. sen.append(1)
  53. self.textual_ids.append(sen)
  54.  
  55.  
  56. self.vocab_size = cur_id
  57. print(len(self.ixtoword))
  58. print(len(self.wordtoix))
  59. return self.textual_ids, self.ixtoword, self.wordtoix
  60.  
  61. def get_label(self, index):
  62. ### TO-DO
  63. if self.examples[index].label == 'positive':
  64. return 0
  65. else:
  66. return 1
  67.  
  68. def get_text(self, index):
  69. ### TO-DO
  70. while len(self.textual_ids[index]) < MAX_LEN:
  71. self.textual_ids[index].append(0)
  72. #print(self.textual_ids[index])
  73. return torch.LongTensor(self.textual_ids[index])
  74.  
  75. def __len__(self):
  76. ### TO-DO
  77. return len(self.examples)
  78.  
  79. def __getitem__(self, index):
  80. ### TO-DO
  81. text = self.get_text(index)
  82. lbl = self.get_label(index)
  83. text_len = len(self.examples[index].text)
  84.  
  85. return text, text_len, lbl
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement