Not a member of Pastebin yet? Sign up — it unlocks many cool features!
- import os
- import re
- import numpy as np
- import shapely.geometry
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms, models
- from skimage import io, transform
- from tqdm import tqdm
# Per-feature normalization statistics for grasp vectors (x, y, theta, w, h).
# Presumably computed over the Cornell dataset annotations — TODO confirm.
FEATURE_MEANS = np.array([302.04, 272.97, 102.48, 49.71, 29.17])
FEATURE_STDS = np.array([33.50, 40.19, 96.50, 22.82, 11.89])
def point_mid(a, b):
    """Return the midpoint of two (x, y) points."""
    mid_x = (a[0] + b[0]) / 2
    mid_y = (a[1] + b[1]) / 2
    return (mid_x, mid_y)
def point_dist(a, b):
    """Return the Euclidean distance between two (x, y) points."""
    diff_x = a[0] - b[0]
    diff_y = a[1] - b[1]
    return np.sqrt(diff_x ** 2 + diff_y ** 2)
def point_angle(a, b):
    """Return the angle (in degrees, offset by +90) of the vector from b to a."""
    delta_x = a[0] - b[0]
    delta_y = a[1] - b[1]
    return np.degrees(np.arctan2(delta_y, delta_x)) + 90
def rotate_point(p, theta):
    """Rotate point p about the origin by theta degrees (counter-clockwise)."""
    rad = np.radians(theta)
    cos_t, sin_t = np.cos(rad), np.sin(rad)
    return (cos_t * p[0] - sin_t * p[1],
            sin_t * p[0] + cos_t * p[1])
def rect_to_points(x, y, theta, w, h):
    """Return the four corners of a w-by-h rectangle centered at (x, y).

    Corners are listed counter-clockwise starting from (x - w/2, y - h/2),
    each rotated by theta degrees about the origin (not the rect center).
    """
    half_w, half_h = w / 2, h / 2
    corners = [(x - half_w, y - half_h),
               (x - half_w, y + half_h),
               (x + half_w, y + half_h),
               (x + half_w, y - half_h)]
    return [rotate_point(corner, theta) for corner in corners]
def test_proposals(props, gts):
    """Evaluate grasp proposals against ground truths (rectangle metric).

    A proposal counts as correct when its rotated-rectangle IoU with the
    ground truth is at least 0.25 AND the angle difference is below 30
    degrees (checked with 360-degree wrap-around).

    Args:
        props: batch of normalized (x, y, theta, w, h) proposal tensors.
        gts: batch of normalized ground-truth tensors, same layout.

    Returns:
        List of booleans, one per (proposal, ground-truth) pair.
    """
    results = []
    for raw_prop, raw_gt in zip(props, gts):
        # Undo feature normalization to recover pixel/degree units.
        prop = raw_prop.detach().numpy() * FEATURE_STDS + FEATURE_MEANS
        gt = raw_gt.detach().numpy() * FEATURE_STDS + FEATURE_MEANS
        prop_poly = shapely.geometry.Polygon(rect_to_points(*prop))
        gt_poly = shapely.geometry.Polygon(rect_to_points(*gt))
        intersect_area = prop_poly.intersection(gt_poly).area
        union_area = prop_poly.union(gt_poly).area
        # Guard against degenerate (zero-area) rectangles.
        iou = intersect_area / union_area if union_area > 0 else 0.0
        min_angle, max_angle = sorted((prop[2], gt[2]))
        # Accept either the direct difference or the wrap-around difference.
        angles_work = (max_angle - min_angle < 30 or
                       min_angle + 360 - max_angle < 30)
        results.append(iou >= 0.25 and angles_work)
    return results
class CornellGraspDataset2d(Dataset):
    """Cornell grasping dataset: RGB images paired with 2d grasp rectangles.

    Each sample is (image, grasp) where grasp is a FloatTensor
    (x, y, theta, w, h) taken from the sample's positive rectangles.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: directory holding pcd{id}r.png / pcd{id}cpos.txt files.
            transform: optional callable applied to (PIL image, grasp) pairs.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.ids = []
        for fname in os.listdir(self.root_dir):
            if fname.endswith('.png'):
                # File names look like pcd0123r.png; keep the numeric id.
                # Raw string avoids the invalid-escape '\d' warning.
                sample_id = re.findall(r'\d+', fname)[0]
                self.ids.append(sample_id)
        self.ids.sort()  # for determinism

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def parse_rects(rect_path, id):
        """Parse a cpos/cneg annotation file into (x, y, theta, w, h) tensors.

        The file lists rectangles as groups of four "x y" corner lines;
        groups containing NaN coordinates are skipped. The `id` parameter is
        kept for interface compatibility and is currently unused.
        """
        rects = []
        with open(rect_path) as f:
            lines = f.read().splitlines()
        for i in range(0, len(lines), 4):
            rect_lines = lines[i:(i + 4)]
            points = []
            valid = True
            for rect_line in rect_lines:
                point_values = rect_line.strip().split(' ')
                point = tuple(float(value) for value in point_values)
                if np.any(np.isnan(point)):
                    valid = False
                points.append(point)
            if not valid:
                continue
            # Center from opposite corners; angle/size from adjacent edges.
            x, y = point_mid(points[0], points[2])
            theta = point_angle(points[1], points[2])
            w = point_dist(points[0], points[1])
            h = point_dist(points[1], points[2])
            rects.append(torch.FloatTensor([x, y, theta, w, h]))
        return rects

    def __getitem__(self, idx):
        sample_id = self.ids[idx]
        img_path = os.path.join(self.root_dir, 'pcd{}r.png'.format(sample_id))
        img = io.imread(img_path)
        pos_path = os.path.join(self.root_dir, 'pcd{}cpos.txt'.format(sample_id))
        pos = self.parse_rects(pos_path, sample_id)
        # First rectangle keeps sampling deterministic; a random choice
        # (np.random.choice) would be the augmented alternative.
        gt = pos[0]
        if self.transform:
            img = transforms.functional.to_pil_image(img)
            img, gt = self.transform((img, gt))
        img = transforms.functional.to_tensor(img)
        return img, gt
class RandomTranslate(object):
    """Randomly shift an image and its grasp rectangle by up to +/- shift px."""

    def __init__(self, shift):
        # Maximum absolute translation (pixels) along each axis.
        self.shift = shift

    def __call__(self, sample):
        img, gt = sample
        # randint's upper bound is exclusive; +1 makes the range symmetric
        # in [-shift, shift] instead of [-shift, shift - 1].
        x_shift = np.random.randint(-self.shift, self.shift + 1)
        y_shift = np.random.randint(-self.shift, self.shift + 1)
        shift = (x_shift, y_shift)
        new_img = transforms.functional.affine(img, 0, shift, 1, 0)
        # Move the grasp center with the shifted image content.
        new_gt = np.copy(gt)
        new_gt[0] += x_shift
        new_gt[1] += y_shift
        return new_img, new_gt
class Resize(object):
    """Resize the image to a square and rescale the grasp accordingly."""

    def __init__(self, new_size):
        self.new_size = new_size

    def __call__(self, sample):
        img, gt = sample
        # PIL's Image.size is (width, height) — the original unpacked it
        # in (h, w) order, swapping the scale factors on non-square images.
        w, h = img.size
        new_img = transforms.functional.resize(img, (self.new_size,
                                                     self.new_size))
        new_gt = np.copy(gt)
        new_gt[0] *= self.new_size / w   # x scales with width
        new_gt[1] *= self.new_size / h   # y scales with height
        new_gt[3] *= self.new_size / w   # rectangle width
        new_gt[4] *= self.new_size / h   # rectangle height
        return new_img, new_gt
class CenterCrop(object):
    """Center-crop a square region and translate the grasp into crop coords."""

    def __init__(self, new_size):
        self.new_size = new_size

    def __call__(self, sample):
        img, gt = sample
        # PIL's Image.size is (width, height) — the original unpacked it
        # in (h, w) order, so top/left came from the wrong dimensions on
        # non-square images.
        w, h = img.size
        top = (h - self.new_size) // 2
        left = (w - self.new_size) // 2
        new_img = transforms.functional.crop(img, top, left,
                                             self.new_size, self.new_size)
        new_gt = np.copy(gt)
        new_gt[0] -= left  # x is measured from the left edge
        new_gt[1] -= top   # y is measured from the top edge
        return new_img, new_gt
class GraspNormalize(object):
    """Standardize the grasp vector as (gt - mean) / std; the image passes through."""

    def __init__(self, mean, std):
        # Stored as FloatTensors so broadcasting works on tensor grasps.
        self.mean = torch.FloatTensor(mean)
        self.std = torch.FloatTensor(std)

    def __call__(self, sample):
        img, gt = sample
        normalized = (gt - self.mean) / self.std
        return img, normalized
class ResNetGrasp(nn.Module):
    """ResNet-50 backbone with a small MLP head regressing (x, y, theta, w, h)."""

    def __init__(self):
        super(ResNetGrasp, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        # Keep everything up to (and including) global pooling; drop the
        # final fully-connected classification layer.
        backbone_layers = list(self.resnet.children())[:-1]
        self.features = nn.Sequential(*backbone_layers)
        num_ftrs = self.resnet.fc.in_features
        self.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Linear(128, 5))

    def forward(self, x):
        feats = self.features(x)
        # Flatten (N, C, 1, 1) pooled features to (N, C) for the head.
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)
def train_model(model, dataloader, num_epochs=25, criterion=None,
                optimizer=None):
    """Train a grasp regressor and report per-epoch loss and rect-metric accuracy.

    Args:
        model: network mapping image batches to (x, y, theta, w, h) proposals.
        dataloader: yields (images, normalized ground-truth grasp) batches.
        num_epochs: number of passes over the dataloader.
        criterion: loss; defaults to MSE on the normalized grasp vector.
        optimizer: defaults to Adam over all model parameters.
    """
    if criterion is None:
        criterion = nn.MSELoss()
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters())
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        model.train()
        running_loss = 0.0
        running_corrects = 0
        num_samples = 0
        # Iterate the loader directly; the debug version materialized it
        # with list() and trained on the first batch only.
        for imgs, gts in tqdm(dataloader):
            optimizer.zero_grad()
            props = model(imgs)
            loss = criterion(props, gts)
            loss.backward()
            optimizer.step()
            batch_size = imgs.size(0)
            running_loss += loss.item() * batch_size
            running_corrects += sum(test_proposals(props, gts))
            num_samples += batch_size
        # Average over samples actually seen, not a hard-coded count of 10.
        epoch_loss = running_loss / max(num_samples, 1)
        epoch_acc = running_corrects / max(num_samples, 1)
        print('Loss: {}, Acc: {}'.format(epoch_loss, epoch_acc))
def main():
    """Build the augmentation pipeline, dataset, and model, then train."""
    composed = transforms.Compose([CenterCrop(224),
                                   RandomTranslate(50),
                                   Resize(224),
                                   GraspNormalize(mean=FEATURE_MEANS,
                                                  std=FEATURE_STDS)])
    cornell_dataset = CornellGraspDataset2d('data', transform=composed)
    cornell_dataloader = DataLoader(cornell_dataset, batch_size=8,
                                    shuffle=False, num_workers=16)
    model = ResNetGrasp()
    train_model(model, cornell_dataloader)


if __name__ == '__main__':
    # The guard is required for num_workers > 0 DataLoaders on spawn-based
    # platforms, and prevents training from running on a mere import.
    main()
Add a comment
Please sign in to add a comment.