daily pastebin goal
53%
SHARE
TWEET

Untitled

a guest Dec 18th, 2018 64 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. # coding: utf-8
  2.  
  3. # In[1]:
  4.  
  5.  
  6. import os
  7. import imageio
  8. import time
  9.  
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. import torchvision
  14.  
  15. from PIL import Image
  16.  
  17.  
  18. # In[2]:
  19.  
  20.  
  21. # dataset
  22. class TrainDataset(torch.utils.data.dataset.Dataset):
  23.     def __init__(self, root_dir=os.path.join(os.getcwd(), "xray_images"), transform=None):
  24.         self.root_dir = root_dir
  25.         self.transform = transform
  26.    
  27.     def __len__(self):
  28.         return 15000
  29.    
  30.     def __getitem__(self, index):
  31.         img64_path = os.path.join(self.root_dir, 'train_images_64x64', f'train_{index+4000:05d}.png')
  32.         img128_path = os.path.join(self.root_dir, 'train_images_128x128', f'train_{index+4000:05d}.png')
  33.         img64 = imageio.imread(img64_path)
  34.         img128 = imageio.imread(img128_path)
  35.         if self.transform:
  36.             img64 = self.transform(img64)
  37.             img128 = self.transform(img128)
  38.         return img64[0:1, :, :], img128[0:1, :, :]
  39.    
  40. class ValidDataset(torch.utils.data.dataset.Dataset):
  41.     def __init__(self, root_dir=os.path.join(os.getcwd(), "xray_images"), transform=None):
  42.         self.root_dir = root_dir
  43.         self.transform = transform
  44.    
  45.     def __len__(self):
  46.         return 1000
  47.    
  48.     def __getitem__(self, index):
  49.         img64_path = os.path.join(self.root_dir, 'train_images_64x64', f'train_{index+19001:05d}.png')
  50.         img128_path = os.path.join(self.root_dir, 'train_images_128x128', f'train_{index+19001:05d}.png')
  51.         img64 = imageio.imread(img64_path)
  52.         img128 = imageio.imread(img128_path)
  53.         if self.transform:
  54.             img64 = self.transform(img64)
  55.             img128 = self.transform(img128)
  56.         return img64[0:1, :, :], img128[0:1, :, :]
  57.    
  58. class TestDataset(torch.utils.data.dataset.Dataset):
  59.     def __init__(self, root_dir=os.path.join(os.getcwd(), "xray_images"), transform=None):
  60.         self.root_dir = root_dir
  61.         self.transform = transform
  62.    
  63.     def __len__(self):
  64.         return 3999
  65.    
  66.     def __getitem__(self, index):
  67.         img64_path = os.path.join(self.root_dir, 'test_images_64x64', f'test_{index+1:05d}.png')
  68.         img64 = imageio.imread(img64_path)
  69.         if self.transform:
  70.             img64 = self.transform(img64)
  71.         return img64[0:1, :, :]
  72.  
  73.  
  74. # In[3]:
  75.  
  76.  
  77. class Net(nn.Module):
  78.     def __init__(self):
  79.         super().__init__()
  80.         self.conv1 = nn.Conv2d(1, 64, 9, 1, 4)
  81.         self.conv2 = nn.Conv2d(64, 64, 5, 1, 2)
  82.         self.conv3 = nn.Conv2d(64, 1, 5, 1, 2)
  83.    
  84.     def forward(self, x):
  85.         x = F.relu(self.conv1(x))
  86.         x = F.relu(self.conv2(x))
  87.         x = F.relu(self.conv3(x))
  88.         return x
  89.  
  90.  
  91. # In[5]:
  92.  
  93.  
  94. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  95. # hyperparameters
  96.  
  97. num_epochs = 5
  98. batch_size = 30
  99. learning_rate = 0.001
  100. criterion = nn.MSELoss(reduction='sum').to(device)
  101.  
  102. upscale_factor = 2
  103.  
  104. def img_preprocess(data, interpolation='bicubic'):
  105.     if interpolation == 'bicubic':
  106.         interpolation = Image.BICUBIC
  107.     elif interpolation == 'bilinear':
  108.         interpolation = Image.BILINEAR
  109.     elif interpolation == 'nearest':
  110.         interpolation = Image.NEAREST
  111.  
  112.     size = list(data.shape)
  113.  
  114.     if len(size) == 4:
  115.         target_height = int(size[2] * upscale_factor)
  116.         target_width = int(size[3] * upscale_factor)
  117.         out_data = torch.FloatTensor(size[0], size[1], target_height, target_width)
  118.         for i, img in enumerate(data):
  119.             transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
  120.                                             torchvision.transforms.Resize((target_width, target_height), interpolation=interpolation),
  121.                                             torchvision.transforms.ToTensor()])
  122.  
  123.             out_data[i, :, :, :] = transform(img)
  124.         return out_data
  125.     else:
  126.         target_height = int(size[1] * upscale_factor)
  127.         target_width = int(size[2] * upscale_factor)
  128.         transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
  129.                                         torchvision.transforms.Resize((target_width, target_height), interpolation=interpolation),
  130.                                         torchvision.transforms.ToTensor()])
  131.         return transform(data)
  132.  
  133.  
  134. # In[6]:
  135.  
  136.  
  137. train_dataset = TrainDataset(root_dir='./xray_images',
  138.                              transform=torchvision.transforms.Compose([
  139.                                      torchvision.transforms.ToTensor()
  140.                                  ]))
  141. train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
  142.                                            batch_size=batch_size,
  143.                                            shuffle=True)
  144.  
  145. valid_dataset = ValidDataset(root_dir='./xray_images',
  146.                             transform=torchvision.transforms.ToTensor())
  147. valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
  148.                                            batch_size=batch_size)
  149.  
  150. test_dataset = TestDataset(root_dir='./xray_images',
  151.                              transform=torchvision.transforms.ToTensor())
  152. test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
  153.                                           batch_size=batch_size,
  154.                                           shuffle=True)
  155.  
  156.  
  157. # In[6]:
  158.  
  159.  
  160. model = Net().to(device)
  161. optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  162.  
  163. # load the model
  164. # model.load_state_dict(torch.load('SRCNN2.pt', map_location=device))
  165.  
  166. def adjust_learning_rate(optimizer, lr):
  167.     for param_group in optimizer.param_groups:
  168.         param_group['lr'] = lr
  169.  
  170.  
  171. # In[7]:
  172.  
  173.  
  174. # Train the model
  175.  
  176. adjust_learning_rate(optimizer, 0.0003)
  177.  
  178. model.train()
  179. total_step = len(train_loader)
  180. for epoch in range(num_epochs):
  181.     start = time.time()
  182.     for i, (img64, img128) in enumerate(train_loader):
  183.         img64, img128 = img_preprocess(img64).to(device), img128.to(device)
  184.        
  185.         # Forward pass
  186.         outputs = model(img64)
  187.         loss = criterion(outputs, img128)
  188.        
  189.         # Backward and optimize
  190.         optimizer.zero_grad()
  191.         loss.backward()
  192.         optimizer.step()
  193.        
  194.     # epoch end
  195.     end = time.time()
  196.     with torch.no_grad():
  197.         valid_loss = 0
  198.         for _, (valid64, valid128) in enumerate(valid_loader):
  199.             valid64, valid128 = img_preprocess(valid64).to(device), valid128.to(device)
  200.             valid_outputs = model(valid64)
  201.             valid_loss += criterion(valid_outputs, valid128).item()
  202.         print("Epoch [{}/{}], Step [{}/{}] loss: {:.2f} E[Loss]: {:.2f} Time/epoch: {:.2f} secs"
  203.                .format(epoch+1, num_epochs, i+1, total_step, loss.item(), valid_loss*2, end-start))
  204.    
  205.     # Save the model checkpoint
  206.     torch.save(model.state_dict(), 'SRCNN2.pt')
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top