Guest User

Untitled

a guest
May 23rd, 2018
87
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.58 KB | None | 0 0
  1. import os
  2. import re
  3. import numpy as np
  4. import shapely.geometry
  5.  
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from torch.utils.data import Dataset, DataLoader
  10. from torchvision import transforms, models
  11.  
  12. from skimage import io, transform
  13. from tqdm import tqdm
  14.  
  15. FEATURE_MEANS = np.array([302.04, 272.97, 102.48, 49.71, 29.17])
  16. FEATURE_STDS = np.array([33.50, 40.19, 96.50, 22.82, 11.89])
  17.  
  18.  
  19. def point_mid(a, b):
  20. return ((a[0] + b[0])/2, (a[1] + b[1])/2)
  21.  
  22.  
  23. def point_dist(a, b):
  24. dx, dy = a[0] - b[0], a[1] - b[1]
  25. return np.sqrt(dx**2 + dy**2)
  26.  
  27.  
  28. def point_angle(a, b):
  29. dx, dy = a[0] - b[0], a[1] - b[1]
  30. return np.degrees(np.arctan2(dy, dx)) + 90
  31.  
  32.  
  33. def rotate_point(p, theta):
  34. theta_rad = np.radians(theta)
  35. return (np.cos(theta_rad) * p[0] - np.sin(theta_rad) * p[1],
  36. np.sin(theta_rad) * p[0] + np.cos(theta_rad) * p[1])
  37.  
  38.  
  39. def rect_to_points(x, y, theta, w, h):
  40. points = [(x - w/2, y - h/2),
  41. (x - w/2, y + h/2),
  42. (x + w/2, y + h/2),
  43. (x + w/2, y - h/2)]
  44. return [rotate_point(p, theta) for p in points]
  45.  
  46.  
  47. def test_proposals(props, gts):
  48. results = []
  49. for raw_prop, raw_gt in zip(props, gts):
  50. prop = raw_prop.detach().numpy() * FEATURE_STDS + FEATURE_MEANS
  51. gt = raw_gt.detach().numpy() * FEATURE_STDS + FEATURE_MEANS
  52.  
  53. prop_poly = shapely.geometry.Polygon(rect_to_points(*prop))
  54. gt_poly = shapely.geometry.Polygon(rect_to_points(*gt))
  55. intersect_area = prop_poly.intersection(gt_poly).area
  56. union_area = prop_poly.union(gt_poly).area
  57.  
  58. print(intersect_area)
  59. print(prop)
  60. print(gt)
  61.  
  62. min_angle, max_angle = (prop[2], gt[2]) if prop[2] < gt[2] else (gt[2], prop[2])
  63. angles_work = (max_angle - min_angle < 30 or
  64. min_angle + 360 - max_angle < 30)
  65.  
  66. result = intersect_area / union_area >= 0.25 and angles_work
  67. results.append(result)
  68.  
  69. return results
  70.  
  71.  
  72. class CornellGraspDataset2d(Dataset):
  73. def __init__(self, root_dir, transform=None):
  74. self.root_dir = root_dir
  75. self.transform = transform
  76. self.ids = []
  77. for fname in os.listdir(self.root_dir):
  78. if fname.endswith('.png'):
  79. id = re.findall('\d+', fname)[0]
  80. self.ids.append(id)
  81. self.ids.sort() # for determinism
  82.  
  83. def __len__(self):
  84. return len(self.ids)
  85.  
  86. @staticmethod
  87. def parse_rects(rect_path, id):
  88. rects = []
  89. with open(rect_path) as f:
  90. lines = f.read().splitlines()
  91.  
  92. for i in range(0, len(lines), 4):
  93. rect_lines = lines[i:(i + 4)]
  94. points = []
  95.  
  96. valid = True
  97. for rect_line in rect_lines:
  98. point_values = rect_line.strip().split(' ')
  99. point = tuple(float(value) for value in point_values)
  100. if np.any(np.isnan(point)):
  101. valid = False
  102. points.append(point)
  103.  
  104. if not valid:
  105. continue
  106.  
  107. x, y = point_mid(points[0], points[2])
  108. theta = point_angle(points[1], points[2])
  109. w = point_dist(points[0], points[1])
  110. h = point_dist(points[1], points[2])
  111.  
  112. rects.append(torch.FloatTensor([x, y, theta, w, h]))
  113. with open('derp.csv', 'a') as f:
  114. f.write(','.join([str(v) for v in (x, y, theta, w, h)]) + '\n')
  115.  
  116. return rects
  117.  
  118. def __getitem__(self, idx):
  119. id = self.ids[idx]
  120. img_path = os.path.join(self.root_dir, 'pcd{}r.png'.format(id))
  121. img = io.imread(img_path)
  122.  
  123. pos_path = os.path.join(self.root_dir, 'pcd{}cpos.txt'.format(id))
  124. pos = self.parse_rects(pos_path, id)
  125. # gt = pos[np.random.choice(len(pos))]
  126. gt = pos[0]
  127.  
  128. if self.transform:
  129. img = transforms.functional.to_pil_image(img)
  130. img, gt = self.transform((img, gt))
  131. img = transforms.functional.to_tensor(img)
  132.  
  133. return img, gt
  134.  
  135.  
  136. class RandomTranslate(object):
  137. def __init__(self, shift):
  138. self.shift = shift
  139.  
  140. def __call__(self, sample):
  141. img, gt = sample
  142. x_shift = np.random.randint(-self.shift, self.shift)
  143. y_shift = np.random.randint(-self.shift, self.shift)
  144. shift = (x_shift, y_shift)
  145. new_img = transforms.functional.affine(img, 0, shift, 1, 0)
  146. new_gt = np.copy(gt)
  147. new_gt[0] += x_shift
  148. new_gt[1] += y_shift
  149. return new_img, new_gt
  150.  
  151.  
  152. class Resize(object):
  153. def __init__(self, new_size):
  154. self.new_size = new_size
  155.  
  156. def __call__(self, sample):
  157. img, gt = sample
  158. h, w = img.size[:2]
  159. new_img = transforms.functional.resize(img, (self.new_size,
  160. self.new_size))
  161. new_gt = np.copy(gt)
  162. new_gt[0] *= self.new_size/w
  163. new_gt[1] *= self.new_size/h
  164. new_gt[3] *= self.new_size/w
  165. new_gt[4] *= self.new_size/h
  166. return new_img, new_gt
  167.  
  168.  
  169. class CenterCrop(object):
  170. def __init__(self, new_size):
  171. self.new_size = new_size
  172.  
  173. def __call__(self, sample):
  174. img, gt = sample
  175. h, w = img.size[:2]
  176. i = (h - self.new_size)//2
  177. j = (w - self.new_size)//2
  178. new_img = transforms.functional.crop(img, i, j, self.new_size, self.new_size)
  179. new_gt = np.copy(gt)
  180. new_gt[0] -= j
  181. new_gt[1] -= i
  182. return new_img, new_gt
  183.  
  184.  
  185. class GraspNormalize(object):
  186. def __init__(self, mean, std):
  187. self.mean = torch.FloatTensor(mean)
  188. self.std = torch.FloatTensor(std)
  189.  
  190. def __call__(self, sample):
  191. img, gt = sample
  192. return img, (gt - self.mean)/self.std
  193.  
  194.  
  195. class ResNetGrasp(nn.Module):
  196. def __init__(self):
  197. super(ResNetGrasp, self).__init__()
  198. self.resnet = models.resnet50(pretrained=True)
  199. self.features = nn.Sequential(*list(self.resnet.children())[:-1])
  200.  
  201. num_ftrs = self.resnet.fc.in_features
  202. self.classifier = nn.Sequential(
  203. nn.Linear(num_ftrs, 128),
  204. nn.ReLU(),
  205. nn.Linear(128, 5))
  206.  
  207. def forward(self, x):
  208. x = self.features(x)
  209. x = x.view(x.size(0), -1)
  210. return self.classifier(x)
  211.  
  212.  
  213. def train_model(model, dataloader, num_epochs=25, criterion=None,
  214. optimizer=None):
  215. if criterion is None:
  216. criterion = nn.MSELoss()
  217.  
  218. if optimizer is None:
  219. optimizer = torch.optim.Adam(model.parameters())
  220.  
  221. for epoch in range(num_epochs):
  222. print('Epoch {}/{}'.format(epoch, num_epochs-1))
  223. model.train()
  224.  
  225. running_loss = 0.0
  226. running_corrects = 0
  227.  
  228. for imgs, gts in tqdm(list(dataloader)[:1]):
  229. optimizer.zero_grad()
  230. props = model(imgs)
  231. loss = criterion(props, gts)
  232. loss.backward()
  233. optimizer.step()
  234. running_loss += loss.item() * imgs.size(0)
  235. running_corrects += sum(test_proposals(props, gts))
  236.  
  237. epoch_loss = running_loss / 10
  238. epoch_acc = running_corrects / 10
  239.  
  240. print('Loss: {}, Acc: {}'.format(epoch_loss, epoch_acc))
  241.  
  242.  
  243. composed = transforms.Compose([CenterCrop(224),
  244. RandomTranslate(50),
  245. Resize(224),
  246. GraspNormalize(mean=FEATURE_MEANS,
  247. std=FEATURE_STDS)])
  248.  
  249. cornell_dataset = CornellGraspDataset2d('data', transform=composed)
  250. cornell_dataloader = DataLoader(cornell_dataset, batch_size=8,
  251. shuffle=False, num_workers=16)
  252.  
  253. model = ResNetGrasp()
  254. train_model(model, cornell_dataloader)
Add Comment
Please, Sign In to add comment