from typing import Dict, List, Optional, Tuple

import json

import nltk
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer


class VoxelSentenceMappingRegDataset:
    """Base dataset that loads sentence-to-voxel regression samples from a JSON file."""

    def __init__(
        self,
        json_path: str,
        bert_tokenizer_path_or_name: str,
        use_all_voxels: bool = False,
        ind2anchors: Optional[Dict] = None,
    ):
        with open(json_path) as json_file:
            self.json_data = json.load(json_file)
        (
            self.sentences,
            self.mappings,
            self.keywords,
            self.organs_indices,
            self.docs_indices,
        ) = ([], [], [], [], [])

        for element in self.json_data:
            # Skip overly long sentences (more than 200 characters).
            if len(element["text"]) > 200:
                continue
            self.sentences.append(element["text"])
            if use_all_voxels and ind2anchors:
                self.mappings.append(
                    [ind2anchors[ind] for ind in element["organ_indices"]]
                )
            else:
                self.mappings.append(element["centers"])
            self.keywords.append(element["keywords"])
            self.organs_indices.append(element["organ_indices"])
            self.docs_indices.append(element["paper_idx"])
        self.tokenizer = BertTokenizer.from_pretrained(bert_tokenizer_path_or_name)

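# A minimal sketch of one entry in the JSON file the regression dataset expects,
# inferred from the keys read above ("text", "centers", "keywords",
# "organ_indices", "paper_idx"); the values below are illustrative placeholders,
# not real data.
#
# {
#     "text": "The liver lies in the right upper quadrant of the abdomen.",
#     "centers": [[12.5, -3.0, 40.2]],
#     "keywords": ["liver"],
#     "organ_indices": [3],
#     "paper_idx": 0
# }
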
class VoxelSentenceMappingTrainRegDataset(VoxelSentenceMappingRegDataset, Dataset):
    def __init__(
        self,
        json_path: str,
        bert_tokenizer_path_or_name: str,
        mask_probability: float,
        use_all_voxels: bool = False,
        ind2anchors: Optional[Dict] = None,
    ):
        super().__init__(
            json_path, bert_tokenizer_path_or_name, use_all_voxels, ind2anchors
        )
        self.mask_probability = mask_probability

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx: int):
        # Randomly decide, per keyword, whether to mask it in this sample.
        mask = {
            word: np.random.choice(
                [0, 1], p=[1 - self.mask_probability, self.mask_probability]
            )
            for word in self.keywords[idx]
        }
        # Requires the NLTK "punkt" tokenizer data (nltk.download("punkt")).
        masked_sentence = " ".join(
            "[MASK]" if word in mask and mask[word] == 1 else word
            for word in nltk.word_tokenize(self.sentences[idx])
        )
        tokenized_sentence = torch.tensor(
            self.tokenizer.encode(masked_sentence, add_special_tokens=True)
        )
        mapping = torch.tensor(self.mappings[idx])
        organ_indices = torch.tensor(self.organs_indices[idx])
        doc_indices = torch.tensor(self.docs_indices[idx])
        num_organs = len(mapping)

        return tokenized_sentence, mapping, num_organs, organ_indices, doc_indices

class VoxelSentenceMappingTestRegDataset(VoxelSentenceMappingRegDataset, Dataset):
    def __init__(
        self,
        json_path: str,
        bert_tokenizer_path_or_name: str,
        use_all_voxels: bool = False,
        ind2anchors: Optional[Dict] = None,
    ):
        super().__init__(
            json_path, bert_tokenizer_path_or_name, use_all_voxels, ind2anchors
        )

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx: int):
        tokenized_sentence = torch.tensor(
            self.tokenizer.encode(self.sentences[idx], add_special_tokens=True)
        )
        mapping = torch.tensor(self.mappings[idx])
        organ_indices = torch.tensor(self.organs_indices[idx])
        doc_indices = torch.tensor(self.docs_indices[idx])
        num_organs = len(mapping)

        return tokenized_sentence, mapping, num_organs, organ_indices, doc_indices

class VoxelSentenceMappingTestMaskedRegDataset(VoxelSentenceMappingRegDataset, Dataset):
    def __init__(
        self,
        json_path: str,
        bert_tokenizer_path_or_name: str,
        use_all_voxels: bool = False,
        ind2anchors: Optional[Dict] = None,
    ):
        super().__init__(
            json_path, bert_tokenizer_path_or_name, use_all_voxels, ind2anchors
        )

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx: int):
        # Mask every keyword (no randomness at test time).
        mask = set(self.keywords[idx])
        masked_sentence = " ".join(
            "[MASK]" if word in mask else word
            for word in nltk.word_tokenize(self.sentences[idx])
        )
        tokenized_sentence = torch.tensor(
            self.tokenizer.encode(masked_sentence, add_special_tokens=True)
        )
        mapping = torch.tensor(self.mappings[idx])
        organ_indices = torch.tensor(self.organs_indices[idx])
        doc_indices = torch.tensor(self.docs_indices[idx])
        num_organs = len(mapping)

        return tokenized_sentence, mapping, num_organs, organ_indices, doc_indices

def collate_pad_sentence_reg_batch(
    batch: List[Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]
):
    sentences, mappings, num_organs, organ_indices, doc_indices = zip(*batch)
    padded_sentences = torch.nn.utils.rnn.pad_sequence(sentences, batch_first=True)
    padded_mappings = torch.nn.utils.rnn.pad_sequence(mappings, batch_first=True)
    num_organs = torch.tensor([*num_organs])
    # Pad organ indices with -1 so padding can be distinguished from real indices.
    padded_organ_indices = torch.nn.utils.rnn.pad_sequence(
        organ_indices, batch_first=True, padding_value=-1
    )

    return (
        padded_sentences,
        padded_mappings,
        num_organs,
        padded_organ_indices,
        doc_indices,
    )

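# A minimal usage sketch (not part of the original paste): wiring the training
# regression dataset into a PyTorch DataLoader with the padding collate above.
# The file path, tokenizer name, batch size, and mask probability are
# illustrative assumptions.
def _demo_reg_dataloader():
    from torch.utils.data import DataLoader

    dataset = VoxelSentenceMappingTrainRegDataset(
        json_path="train.json",  # hypothetical path
        bert_tokenizer_path_or_name="bert-base-uncased",
        mask_probability=0.5,  # illustrative value
    )
    loader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        collate_fn=collate_pad_sentence_reg_batch,
    )
    for sentences, mappings, num_organs, organ_indices, doc_indices in loader:
        # sentences: [batch, max_len] padded token ids
        # mappings: [batch, max_organs, ...] padded voxel mappings
        print(sentences.shape, mappings.shape, num_organs)
        break
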
class VoxelSentenceMappingClassDataset:
    """Base dataset for sentence-to-organ classification samples loaded from a JSON file."""

    def __init__(
        self, json_path: str, bert_tokenizer_path_or_name: str, num_classes: int
    ):
        with open(json_path) as json_file:
            self.json_data = json.load(json_file)
        self.sentences, self.organs_indices, self.keywords = [], [], []
        self.num_classes = num_classes
        for element in self.json_data:
            # Skip overly long sentences (more than 200 characters).
            if len(element["text"]) > 200:
                continue
            self.sentences.append(element["text"])
            self.organs_indices.append(element["organ_indices"])
            self.keywords.append(element["keywords"])
        self.tokenizer = BertTokenizer.from_pretrained(bert_tokenizer_path_or_name)

class VoxelSentenceMappingTrainClassDataset(VoxelSentenceMappingClassDataset, Dataset):
    def __init__(
        self,
        json_path: str,
        bert_tokenizer_path_or_name: str,
        mask_probability: float,
        num_classes: int,
    ):
        super().__init__(json_path, bert_tokenizer_path_or_name, num_classes)
        self.mask_probability = mask_probability

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx: int):
        # Randomly decide, per keyword, whether to mask it in this sample.
        mask = {
            word: np.random.choice(
                [0, 1], p=[1 - self.mask_probability, self.mask_probability]
            )
            for word in self.keywords[idx]
        }
        masked_sentence = " ".join(
            "[MASK]" if word in mask and mask[word] == 1 else word
            for word in nltk.word_tokenize(self.sentences[idx])
        )
        tokenized_sentence = torch.tensor(
            self.tokenizer.encode(masked_sentence, add_special_tokens=True)
        )
        organ_indices = torch.tensor(self.organs_indices[idx])
        # Multi-hot target over organ classes.
        one_hot = torch.zeros(self.num_classes)
        one_hot[organ_indices] = 1

        return tokenized_sentence, one_hot

class VoxelSentenceMappingTestClassDataset(VoxelSentenceMappingClassDataset, Dataset):
    def __init__(
        self, json_path: str, bert_tokenizer_path_or_name: str, num_classes: int
    ):
        super().__init__(json_path, bert_tokenizer_path_or_name, num_classes)

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx: int):
        tokenized_sentence = torch.tensor(
            self.tokenizer.encode(self.sentences[idx], add_special_tokens=True)
        )
        organ_indices = torch.tensor(self.organs_indices[idx])
        one_hot = torch.zeros(self.num_classes)
        one_hot[organ_indices] = 1

        return tokenized_sentence, one_hot

class VoxelSentenceMappingTestMaskedClassDataset(
    VoxelSentenceMappingClassDataset, Dataset
):
    def __init__(
        self, json_path: str, bert_tokenizer_path_or_name: str, num_classes: int
    ):
        super().__init__(json_path, bert_tokenizer_path_or_name, num_classes)

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx: int):
        # Mask every keyword (no randomness at test time).
        mask = set(self.keywords[idx])
        masked_sentence = " ".join(
            "[MASK]" if word in mask else word
            for word in nltk.word_tokenize(self.sentences[idx])
        )
        tokenized_sentence = torch.tensor(
            self.tokenizer.encode(masked_sentence, add_special_tokens=True)
        )
        organ_indices = torch.tensor(self.organs_indices[idx])
        one_hot = torch.zeros(self.num_classes)
        one_hot[organ_indices] = 1

        return tokenized_sentence, one_hot

def collate_pad_sentence_class_batch(batch: List[Tuple[torch.Tensor, torch.Tensor]]):
    sentences, organ_indices = zip(*batch)
    padded_sentences = torch.nn.utils.rnn.pad_sequence(sentences, batch_first=True)

    return padded_sentences, torch.stack(organ_indices, dim=0)
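
# A minimal usage sketch (not part of the original paste) for the classification
# datasets: the multi-hot targets are stacked by the collate function above.
# The file path, tokenizer name, and class count are illustrative assumptions.
def _demo_class_dataloader():
    from torch.utils.data import DataLoader

    dataset = VoxelSentenceMappingTestClassDataset(
        json_path="test.json",  # hypothetical path
        bert_tokenizer_path_or_name="bert-base-uncased",
        num_classes=20,  # assumed number of organ classes
    )
    loader = DataLoader(
        dataset, batch_size=8, collate_fn=collate_pad_sentence_class_batch
    )
    for sentences, one_hot_targets in loader:
        # sentences: [batch, max_len] padded token ids
        # one_hot_targets: [batch, num_classes] multi-hot organ labels
        print(sentences.shape, one_hot_targets.shape)
        break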