Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Device that all dataset tensors are moved to. Fall back to the CPU when no
# CUDA device is available, so the script still runs (slowly) instead of
# failing at the first .to(device) call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SegDataset(Dataset):
    """In-memory segmentation dataset of (raw image, label map) pairs.

    Every pair listed in *csv_loc* is read from *data_dir*, center-cropped to
    *crop_size* and converted to tensors on ``device`` once, during
    construction. Items are then served from that pre-built list, which is
    shuffled once at construction time.

    NOTE(review): assumes ``read_csv`` returns a row-indexable sequence whose
    rows are ``(raw_filename, seg_filename)``. If it is ``pandas.read_csv``,
    ``self.images_data[idx]`` selects a *column*, not a row (use ``.iloc``).
    TODO confirm.
    """

    def __init__(self, csv_loc, data_dir, crop_size=128):
        self.data_dir = data_dir
        self.crop_size = crop_size  # must be set before prepare_images() runs
        self.images_data = read_csv(csv_loc)
        self.images = self.prepare_images()

    def transform(self, raw, seg):
        """Center-crop both images identically and turn them into tensors on ``device``.

        Returns ``{'raw': float32 tensor, 'seg': int64 tensor}``, both scaled
        back to the 0..255 pixel-value range.
        """
        crop = transforms.CenterCrop(self.crop_size)
        raw = crop(raw)
        seg = crop(seg)
        # to_tensor divides 8-bit images by 255; multiplying restores pixel values.
        raw = trans_f.to_tensor(raw).mul(255).float().to(device)
        # Round before the integer cast: the float round-trip k / 255 * 255 is not
        # guaranteed to give exactly k, and .long() truncates toward zero, which
        # would silently turn label k into k - 1.
        seg = trans_f.to_tensor(seg).mul(255).round().long().to(device)
        return {'raw': raw, 'seg': seg}

    def prepare_images(self):
        """Load and transform every listed pair, then shuffle the resulting list once."""
        def read_image_by_id(idx, raw_image=True):
            # Column 0 holds the raw image's file name, column 1 the segmentation map's.
            img_name = os.path.join(self.data_dir,
                                    self.images_data[idx][int(not raw_image)])
            # Image.open is lazy and may leave the file handle open. copy() forces
            # the pixels into memory, so the with-block can close the file
            # deterministically.
            with Image.open(img_name) as img:
                return img.copy()

        images = []
        for idx in range(len(self.images_data)):
            raw = read_image_by_id(idx)
            seg = read_image_by_id(idx, False)
            images.append(self.transform(raw, seg))
        random.shuffle(images)
        return images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx]
# Shared per-pixel classification loss: cross-entropy over raw class logits
# (log-softmax is applied internally). "mean" is the library default; it is
# written out so the batch averaging is explicit.
criterion = nn.CrossEntropyLoss(reduction="mean")
def train_model(model, data_dir='Data'):
    """Train *model* for ``EPOCHS`` epochs, printing train and validation metrics.

    Args:
        model: the network passed to ``train_batch`` / ``validate_batch``.
        data_dir: directory containing ``train.csv``, ``val.csv`` and the image
            files they reference (default ``'Data'``).
    """
    train_data = SegDataset(csv_loc=f'{data_dir}/train.csv', data_dir=data_dir)
    # shuffle=True reshuffles on every epoch; SegDataset only shuffles once at load time.
    train_iter = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE,
                                             shuffle=True)
    val_data = SegDataset(csv_loc=f'{data_dir}/val.csv', data_dir=data_dir)
    val_iter = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE)
    for epoch in range(EPOCHS):
        print(f"epoch {epoch + 1}/{EPOCHS}")
        run_proc_on_data(train_batch, model, train_iter)
        run_proc_on_data(validate_batch, model, val_iter)
def run_proc_on_data(func, model, data_iter):
    """Apply *func* to every batch of *data_iter* and report the mean loss and Jaccard.

    Args:
        func: per-batch callable ``func(batch_index, model, batch)`` returning a
            ``(loss, jaccard)`` pair of numbers.
        model: passed through to *func* unchanged.
        data_iter: any iterable of batches (e.g. a DataLoader). It does not need
            to support ``len()``.

    Returns:
        ``(mean_loss, mean_jaccard)`` over the batches seen. Both are ``nan``
        when *data_iter* yields no batches, rather than raising ZeroDivisionError.
    """
    print(func.__name__)
    total_loss, total_jac, n_batches = 0.0, 0.0, 0
    for i, batch in enumerate(data_iter):
        curr_loss, curr_jac = func(i, model, batch)
        total_loss += curr_loss
        total_jac += curr_jac
        n_batches += 1
    if n_batches:
        loss = total_loss / n_batches
        jac = total_jac / n_batches
    else:
        loss = jac = float("nan")
    print(f"loss: {loss} jaccard: {jac}")
    return loss, jac
def _batch_loss_and_jaccard(pred, seg):
    """Return ``(loss, jaccard)`` tensors for one batch of logits and labels.

    ``loss`` is ``criterion`` applied per pixel. ``jaccard`` is the mean, over
    classes that occur in the prediction or the ground truth, of
    |pred ∩ true| / |pred ∪ true| computed on the argmax labels.

    NOTE(review): assumes *pred* holds class logits laid out (N, C, H, W), the
    usual PyTorch conv output, and that *seg* holds integer labels in
    [0, C) with N*H*W elements (e.g. (N, 1, H, W)). TODO confirm against the model.
    """
    # CrossEntropyLoss accepts (N, C, d1, d2, ...) logits with (N, d1, d2, ...)
    # targets directly. Flattening (N, C, H, W) logits with .view(-1, C) would
    # group neighbouring pixels of one channel instead of one pixel's C scores.
    target = seg.reshape(pred.shape[0], *pred.shape[2:])
    loss = criterion(pred, target)

    labels = pred.detach().argmax(dim=1)
    ious = []
    for c in range(pred.shape[1]):
        pred_c = labels == c
        true_c = target == c
        union = (pred_c | true_c).sum()
        if union > 0:  # skip classes absent from both prediction and ground truth
            ious.append((pred_c & true_c).sum().float() / union.float())
    if ious:
        jac = torch.stack(ious).mean()
    else:
        jac = torch.tensor(float("nan"))
    return loss, jac


def train_batch(batch_id, model, batch):
    """Run one optimisation step on *batch*; return ``(loss, jaccard)`` as floats.

    *batch_id* is unused; it is accepted so the function matches the
    ``func(index, model, batch)`` call shape. The optimizer is expected to be
    attached to the model as ``model.optim``.
    """
    model.train()
    model.zero_grad()
    pred = model(batch['raw'])
    loss, jac = _batch_loss_and_jaccard(pred, batch['seg'])
    loss.backward()
    model.optim.step()
    return loss.item(), jac.item()


def validate_batch(batch_id, model, batch):
    """Evaluate *model* on *batch* without updating it; return ``(loss, jaccard)`` as floats.

    Runs in eval mode under ``torch.no_grad()`` (no autograd graph is built), then
    restores whatever train/eval mode the model was in before the call.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred = model(batch['raw'])
            loss, jac = _batch_loss_and_jaccard(pred, batch['seg'])
    finally:
        model.train(was_training)
    return loss.item(), jac.item()
Add Comment
Please, Sign In to add comment