Guest User

HMTORCHDATASET

a guest
May 10th, 2021
103
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.68 KB | None | 0 0
  1. class HMTorchDataset(Dataset):
  2.     def __init__(self, data_path,img_path,tsv_path):
  3.         super().__init__()
  4.         # self.name = data_path
  5.         # self.splits = splits.split(",")
  6.         self.path = data_path
  7.         self.img_path = img_path
  8.         # Loading datasets to data
  9.         self.raw_data = []
  10.         # for split in self.splits:
  11.         #     path = os.path.join("data/", f"{split}.jsonl")
  12.         #     self.raw_data.extend(
  13.         #             [json.loads(jline) for jline in open(path, "r").read().split('\n')]
  14.         #     )
  15.         # print("Load %d data from split(s) %s." % (len(self.raw_data), self.name))
  16.         self.raw_data  = [json.loads(jline) for jline in open(self.path,"r").read().split('\n')]
  17.         # self.raw_data  = [json.loads(jline) for jline in open(self.path,"r")]
  18.        
  19.         # List to dict (for evaluation and others)
  20.         self.id2datum = {datum["id"]: datum for datum in self.raw_data}
  21.  
  22.         # Loading detection features to img_data
  23.         img_data = []
  24.  
  25.         # path = "data/HM_img.tsv"
  26.         img_data.extend(load_obj_tsv(tsv_path, self.id2datum.keys()))
  27.  
  28.         # Convert img list to dict
  29.         self.imgid2img = {}
  30.         for img_datum in img_data:
  31.             # Adding int here to convert 0625 to 625
  32.             self.imgid2img[int(img_datum['img_id'])] = img_datum
  33.  
  34.  
  35.         # Only keep the data with loaded image features
  36.         self.data = []
  37.         for datum in self.raw_data:
  38.             # In HM the Img Id field is simply "id"
  39.             if datum['id'] in self.imgid2img:
  40.                 self.data.append(datum)
  41.  
  42.         print("Use %d data in torch dataset" % (len(self.data)))
  43.         print()
  44.  
  45.     def __len__(self):
  46.         return len(self.data)
  47.  
  48.  
  49.     def __getitem__(self, item: int):
  50.  
  51.         datum = self.data[item]
  52.  
  53.         img_id = datum['id']
  54.         # title = datum['title']
  55.         text = datum['text']
  56.  
  57.  
  58.         # Get image info
  59.         img_info = self.imgid2img[img_id]
  60.         obj_num = img_info['num_boxes']
  61.         feats = img_info['features'].copy()
  62.         boxes = img_info['boxes'].copy()
  63.         assert obj_num == len(boxes) == len(feats)
  64.  
  65.  
  66.         # Normalize the boxes (to 0 ~ 1)
  67.         img_h, img_w = img_info['img_h'], img_info['img_w']
  68.  
  69.         if args.num_pos == 5:
  70.             # For DeVLBert taken from VilBERT
  71.             image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32)
  72.             image_location[:,:4] = boxes
  73.             image_location[:,4] = (image_location[:,3] - image_location[:,1]) * (image_location[:,2] - image_location[:,0]) / (float(img_w) * float(img_h))
  74.             boxes = image_location
  75.  
  76.         boxes = boxes.copy()
  77.         boxes[:, (0, 2)] /= img_w
  78.         boxes[:, (1, 3)] /= img_h
  79.         np.testing.assert_array_less(boxes, 1+1e-5)
  80.         np.testing.assert_array_less(-boxes, 0+1e-5)
  81.  
  82.  
  83.         if args.num_pos == 6:
  84.             # Add width & height
  85.             width = (boxes[:, 2] - boxes[:, 0]).reshape(-1,1)
  86.             height = (boxes[:, 3] - boxes[:, 1]).reshape(-1,1)
  87.  
  88.             boxes = np.concatenate((boxes, width, height), axis=-1)
  89.  
  90.             # In UNITER they use 7 Pos Feats (See _get_img_feat function in their repo)
  91.             if args.model == "U":
  92.                 boxes = np.concatenate([boxes, boxes[:, 4:5]*boxes[:, 5:]], axis=-1)
  93.  
  94.         # Provide label (target) - From hm_data
  95.         if 'label' in datum:
  96.             target = torch.tensor(datum["label"], dtype=torch.float)
  97.             # return img_id, feats, boxes, title, target
  98.             return img_id, feats, boxes, text, target
  99.         else:
  100.             # return img_id, feats, boxes, title
  101.             return img_id, feats, boxes, text
Advertisement
Add Comment
Please, Sign In to add comment