Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # coding: utf-8
- # In[1]:
- import os
- import imageio
- import time
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torchvision
- from PIL import Image
- # In[2]:
- # dataset
class TrainDataset(torch.utils.data.dataset.Dataset):
    """Paired low/high-resolution training x-rays.

    Yields (64x64, 128x128) single-channel tensor pairs read from
    `train_images_64x64` / `train_images_128x128` under `root_dir`.
    File names run train_04000.png .. train_18999.png.
    """

    def __init__(self, root_dir=os.path.join(os.getcwd(), "xray_images"), transform=None):
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        # Fixed-size split of 15000 paired images.
        return 15000

    def __getitem__(self, index):
        # Dataset files are numbered starting at 4000, zero-padded to 5 digits.
        name = f'train_{index+4000:05d}.png'
        low = imageio.imread(os.path.join(self.root_dir, 'train_images_64x64', name))
        high = imageio.imread(os.path.join(self.root_dir, 'train_images_128x128', name))
        if self.transform:
            low = self.transform(low)
            high = self.transform(high)
        # Keep only the first channel (x-rays are effectively grayscale).
        return low[0:1, :, :], high[0:1, :, :]
class ValidDataset(torch.utils.data.dataset.Dataset):
    """Validation split of the paired x-ray data.

    Same layout as TrainDataset but reads files train_19001.png ..
    train_20000.png.
    NOTE(review): index 19000 is used by neither the train split
    (ends at 18999) nor this one — confirm that is intentional.
    """

    def __init__(self, root_dir=os.path.join(os.getcwd(), "xray_images"), transform=None):
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        # Fixed-size split of 1000 paired images.
        return 1000

    def __getitem__(self, index):
        name = f'train_{index+19001:05d}.png'
        low = imageio.imread(os.path.join(self.root_dir, 'train_images_64x64', name))
        high = imageio.imread(os.path.join(self.root_dir, 'train_images_128x128', name))
        if self.transform:
            low = self.transform(low)
            high = self.transform(high)
        # First channel only, as in TrainDataset.
        return low[0:1, :, :], high[0:1, :, :]
class TestDataset(torch.utils.data.dataset.Dataset):
    """Unlabeled test split: 64x64 inputs only, no 128x128 targets.

    Reads test_00001.png .. test_03999.png from `test_images_64x64`.
    """

    def __init__(self, root_dir=os.path.join(os.getcwd(), "xray_images"), transform=None):
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        # Fixed-size split of 3999 images.
        return 3999

    def __getitem__(self, index):
        # Test files are numbered from 1, zero-padded to 5 digits.
        name = f'test_{index+1:05d}.png'
        img = imageio.imread(os.path.join(self.root_dir, 'test_images_64x64', name))
        if self.transform:
            img = self.transform(img)
        # First channel only (grayscale x-rays).
        return img[0:1, :, :]
- # In[3]:
class Net(nn.Module):
    """SRCNN-style 3-layer CNN: 1 -> 64 -> 64 -> 1 channels.

    Each conv uses stride 1 with "same" padding, so spatial size is
    preserved; the network maps an upscaled image to a refined image
    of the same resolution.
    """

    def __init__(self):
        super().__init__()
        # (in_ch, out_ch, kernel, stride, padding); padding = kernel // 2
        # keeps H and W unchanged.
        self.conv1 = nn.Conv2d(1, 64, 9, 1, 4)
        self.conv2 = nn.Conv2d(64, 64, 5, 1, 2)
        self.conv3 = nn.Conv2d(64, 1, 5, 1, 2)

    def forward(self, x):
        # ReLU after every layer — including the output layer, so
        # predicted pixel values are clamped to be non-negative.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        return x
# In[5]:
# Device selection: prefer CUDA when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters.
num_epochs = 5
batch_size = 30
learning_rate = 0.001
upscale_factor = 2  # 64x64 inputs are upsampled to 128x128

# Summed (not mean) squared error over each batch.
criterion = nn.MSELoss(reduction='sum').to(device)
def img_preprocess(data, interpolation='bicubic'):
    """Upsample a tensor image or batch by the module-level `upscale_factor`.

    Args:
        data: a CHW image tensor, or an NCHW batch of image tensors.
        interpolation: 'bicubic' (default), 'bilinear', or 'nearest'.

    Returns:
        Tensor(s) of the same rank with height and width multiplied by
        `upscale_factor`.
    """
    if interpolation == 'bicubic':
        interpolation = Image.BICUBIC
    elif interpolation == 'bilinear':
        interpolation = Image.BILINEAR
    elif interpolation == 'nearest':
        interpolation = Image.NEAREST
    size = list(data.shape)
    if len(size) == 4:
        target_height = int(size[2] * upscale_factor)
        target_width = int(size[3] * upscale_factor)
        # Fix: torchvision's Resize expects (height, width); the original
        # passed (width, height), which only worked because the images
        # happen to be square. Also hoisted the transform out of the loop
        # (it is loop-invariant).
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Resize((target_height, target_width),
                                          interpolation=interpolation),
            torchvision.transforms.ToTensor(),
        ])
        out_data = torch.FloatTensor(size[0], size[1], target_height, target_width)
        for i, img in enumerate(data):
            out_data[i, :, :, :] = transform(img)
        return out_data
    else:
        target_height = int(size[1] * upscale_factor)
        target_width = int(size[2] * upscale_factor)
        # Same (height, width) order fix as above.
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Resize((target_height, target_width),
                                          interpolation=interpolation),
            torchvision.transforms.ToTensor(),
        ])
        return transform(data)
# In[6]:
# Datasets and loaders. ToTensor converts HWC uint8 images to CHW float
# tensors in [0, 1]; the dataset classes then keep only channel 0.
train_dataset = TrainDataset(root_dir='./xray_images',
                             transform=torchvision.transforms.Compose([
                                 torchvision.transforms.ToTensor()
                             ]))
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

valid_dataset = ValidDataset(root_dir='./xray_images',
                             transform=torchvision.transforms.ToTensor())
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                           batch_size=batch_size)

test_dataset = TestDataset(root_dir='./xray_images',
                           transform=torchvision.transforms.ToTensor())
# Fix: the test loader was created with shuffle=True, which scrambles the
# correspondence between predictions and test file indices. Test data must
# be iterated in file order.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size)
# In[6]:
# Model and optimizer. Note: the lr passed here is overridden to 3e-4 by
# adjust_learning_rate() at the start of the training cell below.
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# load the model
# Uncomment to resume training from a previously saved checkpoint:
# model.load_state_dict(torch.load('SRCNN2.pt', map_location=device))
def adjust_learning_rate(optimizer, lr):
    """Set the learning rate of every parameter group in `optimizer` to `lr`."""
    for group in optimizer.param_groups:
        group['lr'] = lr
# In[7]:
# Train the model: per-epoch loop of SGD steps on the train split followed
# by a full pass over the validation split (gradients disabled).
adjust_learning_rate(optimizer, 0.0003)
model.train()
total_step = len(train_loader)
for epoch in range(num_epochs):
    start = time.time()
    for i, (img64, img128) in enumerate(train_loader):
        # Bicubic-upsample the 64x64 input to 128x128 so the network only
        # has to learn the residual refinement.
        img64, img128 = img_preprocess(img64).to(device), img128.to(device)
        # Forward pass
        outputs = model(img64)
        loss = criterion(outputs, img128)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # epoch end
    end = time.time()
    # NOTE(review): model stays in train mode during validation; harmless
    # for this Net (no dropout/batchnorm), but model.eval() would be the
    # conventional guard — confirm if layers are added later.
    with torch.no_grad():
        valid_loss = 0
        for _, (valid64, valid128) in enumerate(valid_loader):
            valid64, valid128 = img_preprocess(valid64).to(device), valid128.to(device)
            valid_outputs = model(valid64)
            # Summed-MSE per batch, accumulated over the whole split.
            valid_loss += criterion(valid_outputs, valid128).item()
    # `loss` here is the last training batch's loss of the epoch.
    # NOTE(review): the *2 scaling of valid_loss is not explained by the
    # visible code — confirm what E[Loss] is meant to report.
    print("Epoch [{}/{}], Step [{}/{}] loss: {:.2f} E[Loss]: {:.2f} Time/epoch: {:.2f} secs"
          .format(epoch+1, num_epochs, i+1, total_step, loss.item(), valid_loss*2, end-start))
# Save the model checkpoint
torch.save(model.state_dict(), 'SRCNN2.pt')
Add Comment
Please, Sign In to add comment