import json

import numpy as np
import torch
from torch.utils.data import Dataset

# NOTE: `load_obj_tsv` and the global `args` come from the surrounding repo
# (LXMERT-style feature loading and an argparse namespace); adjust these
# imports to your project layout.
from utils import load_obj_tsv
from param import args


class HMTorchDataset(Dataset):
    """Torch dataset for the Hateful Memes data: pairs each jsonl entry
    (id, text, optional label) with its precomputed detection features."""

    def __init__(self, data_path, img_path, tsv_path):
        super().__init__()
        self.path = data_path
        self.img_path = img_path

        # Load the jsonl annotations; skip blank lines so a trailing
        # newline does not break json.loads.
        self.raw_data = [
            json.loads(jline)
            for jline in open(self.path, "r").read().split('\n')
            if jline.strip()
        ]

        # List to dict (for evaluation and others)
        self.id2datum = {datum["id"]: datum for datum in self.raw_data}

        # Loading detection features to img_data
        img_data = []
        img_data.extend(load_obj_tsv(tsv_path, self.id2datum.keys()))

        # Convert img list to dict; int() maps ids like "0625" to 625
        self.imgid2img = {}
        for img_datum in img_data:
            self.imgid2img[int(img_datum['img_id'])] = img_datum

        # Only keep the data with loaded image features.
        # In HM the image id field is simply "id".
        self.data = [datum for datum in self.raw_data if datum['id'] in self.imgid2img]
        print("Use %d data in torch dataset" % (len(self.data)))
        print()
    def __len__(self):
        return len(self.data)
    def __getitem__(self, item: int):
        datum = self.data[item]

        img_id = datum['id']
        text = datum['text']

        # Get image info
        img_info = self.imgid2img[img_id]
        obj_num = img_info['num_boxes']
        feats = img_info['features'].copy()
        boxes = img_info['boxes'].copy()
        assert obj_num == len(boxes) == len(feats)

        # Normalize the boxes (to 0 ~ 1)
        img_h, img_w = img_info['img_h'], img_info['img_w']

        if args.num_pos == 5:
            # For DeVLBert, taken from VilBERT: append the relative box area
            # as a fifth position feature before normalizing.
            image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32)
            image_location[:, :4] = boxes
            image_location[:, 4] = (
                (image_location[:, 3] - image_location[:, 1])
                * (image_location[:, 2] - image_location[:, 0])
                / (float(img_w) * float(img_h))
            )
            boxes = image_location

        boxes = boxes.copy()
        boxes[:, (0, 2)] /= img_w
        boxes[:, (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1 + 1e-5)
        np.testing.assert_array_less(-boxes, 0 + 1e-5)

        if args.num_pos == 6:
            # Add width & height as extra position features
            width = (boxes[:, 2] - boxes[:, 0]).reshape(-1, 1)
            height = (boxes[:, 3] - boxes[:, 1]).reshape(-1, 1)
            boxes = np.concatenate((boxes, width, height), axis=-1)

            # UNITER uses 7 position features (see the _get_img_feat function
            # in their repo): additionally append width * height.
            if args.model == "U":
                boxes = np.concatenate([boxes, boxes[:, 4:5] * boxes[:, 5:]], axis=-1)

        # Provide label (target) if available
        if 'label' in datum:
            target = torch.tensor(datum["label"], dtype=torch.float)
            return img_id, feats, boxes, text, target
        else:
            return img_id, feats, boxes, text
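

# --- Usage sketch (not part of the original paste) ---
# A minimal example of driving HMTorchDataset with a DataLoader. The paths
# ("data/train.jsonl", "data/img/", "data/HM_img.tsv") are hypothetical
# placeholders, and the collate_fn assumes every image has the same number
# of boxes and every datum carries a 'label' (the 5-tuple branch above).
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    def hm_collate(batch):
        # Keep ids and raw text as Python lists; stack the per-image numpy
        # features/boxes into batch tensors.
        img_ids, feats, boxes, texts, targets = zip(*batch)
        feats = torch.from_numpy(np.stack(feats)).float()
        boxes = torch.from_numpy(np.stack(boxes)).float()
        targets = torch.stack(targets)
        return list(img_ids), feats, boxes, list(texts), targets

    dataset = HMTorchDataset(
        data_path="data/train.jsonl",  # hypothetical path
        img_path="data/img/",          # hypothetical path
        tsv_path="data/HM_img.tsv",    # hypothetical path
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=hm_collate)

    for img_ids, feats, boxes, texts, targets in loader:
        print(feats.shape, boxes.shape, targets.shape)
        break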