import os
import random

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as trans_f
from PIL import Image

device = torch.device("cuda")

class SegDataset(Dataset):
    def __init__(self, csv_loc, data_dir):
        self.data_dir = data_dir
        # read_csv is assumed to be a project helper that returns a list of
        # (raw_image_path, segmentation_path) rows.
        self.images_data = read_csv(csv_loc)
        self.images = self.prepare_images()

    def transform(self, raw, seg):
        # Crop both images identically so the pixels stay aligned.
        t = transforms.CenterCrop(128)
        raw = t(raw)
        seg = t(seg)
        # to_tensor scales to [0, 1]; multiplying by 255 restores the original
        # pixel values (class ids in the segmentation mask).
        raw = trans_f.to_tensor(raw).mul(255).float().to(device)
        seg = trans_f.to_tensor(seg).mul(255).long().to(device)
        return {'raw': raw, 'seg': seg}

    def prepare_images(self):
        images = []

        def read_image_by_id(idx, raw_image=True):
            # Column 0 holds the raw image path, column 1 the segmentation path.
            img_name = os.path.join(self.data_dir,
                                    self.images_data[idx][int(not raw_image)])
            return Image.open(img_name)

        for idx in range(len(self.images_data)):
            raw = read_image_by_id(idx)
            seg = read_image_by_id(idx, False)
            images.append(self.transform(raw, seg))
        random.shuffle(images)
        return images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx]

criterion = nn.CrossEntropyLoss()

def train_model(model):
    # BATCH_SIZE and EPOCHS are assumed to be defined at module level.
    train_data = SegDataset(csv_loc='Data/train.csv', data_dir='Data')
    train_iter = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE)
    val_data = SegDataset(csv_loc='Data/val.csv', data_dir='Data')
    val_iter = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE)

    for epoch in range(EPOCHS):
        train_stats = run_func_on_data(train_batch, model, train_iter)
        val_stats = run_func_on_data(validate_batch, model, val_iter)

def run_func_on_data(func, model, data_iter):
    # Applies `func` (train_batch or validate_batch) to every batch and
    # averages the loss and Jaccard score over the epoch.
    print(func.__name__)
    loss, jac = 0, 0
    for i, batch in enumerate(data_iter):
        curr_loss, curr_jac = func(i, model, batch)
        loss += curr_loss
        jac += curr_jac

    loss /= len(data_iter)
    jac /= len(data_iter)

    print("loss: " + str(loss) + " jaccard: " + str(jac))
    return loss, jac

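# Hypothetical helper (not part of the original paste): train_batch and
# validate_batch below return a Jaccard score, so a per-batch mean
# intersection-over-union is sketched here, assuming 3 classes and a model
# output whose last dimension holds the class scores.
def jaccard(pred, seg):
    pred_labels = pred.view(-1, 3).argmax(dim=1)
    target = seg.view(-1)
    score = torch.zeros(1, device=pred.device)
    for c in range(3):
        pred_c = pred_labels == c
        target_c = target == c
        union = (pred_c | target_c).sum()
        if union > 0:
            score += (pred_c & target_c).sum().float() / union.float()
    return score / 3
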
def train_batch(batch_id, model, batch):
    model.zero_grad()
    pred = model(batch['raw'])
    # Assumes the model output keeps the 3 class scores in its last dimension,
    # so that .view(-1, 3) lines up with the flattened target mask.
    loss = criterion(pred.view(-1, 3), batch['seg'].view(-1))
    loss.backward()
    # The model is assumed to carry its optimizer as an `optim` attribute.
    model.optim.step()
    jac = jaccard(pred, batch['seg'])

    return loss.item(), jac.item()

def validate_batch(batch_id, model, batch):
    # No gradients are needed during validation.
    with torch.no_grad():
        pred = model(batch['raw'])
        loss = criterion(pred.view(-1, 3), batch['seg'].view(-1))
        jac = jaccard(pred, batch['seg'])

    return loss.item(), jac.item()
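
# --- Hypothetical usage sketch (not in the original paste) ---
# Assumes the Data/ directory, the CSV files and the read_csv helper exist.
# The tiny model below only illustrates the interface train_batch expects:
# per-pixel scores for 3 classes in the last dimension and an `optim`
# attribute holding the optimizer.
BATCH_SIZE = 4
EPOCHS = 10


class TinySegNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.optim = None  # attached after construction

    def forward(self, x):
        # (B, 3, H, W) -> (B, H, W, 3) so class scores sit in the last dim.
        return self.conv(x).permute(0, 2, 3, 1).contiguous()


model = TinySegNet().to(device)
model.optim = torch.optim.Adam(model.parameters(), lr=1e-3)
train_model(model)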