Guest User

Untitled

a guest
Dec 18th, 2018
104
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.59 KB | None | 0 0
  1. # coding: utf-8
  2.  
  3. # In[1]:
  4.  
  5.  
  6. import os
  7. import imageio
  8. import time
  9.  
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. import torchvision
  14.  
  15. from PIL import Image
  16.  
  17.  
  18. # In[2]:
  19.  
  20.  
  21. # dataset
  22. class TrainDataset(torch.utils.data.dataset.Dataset):
  23. def __init__(self, root_dir=os.path.join(os.getcwd(), "xray_images"), transform=None):
  24. self.root_dir = root_dir
  25. self.transform = transform
  26.  
  27. def __len__(self):
  28. return 15000
  29.  
  30. def __getitem__(self, index):
  31. img64_path = os.path.join(self.root_dir, 'train_images_64x64', f'train_{index+4000:05d}.png')
  32. img128_path = os.path.join(self.root_dir, 'train_images_128x128', f'train_{index+4000:05d}.png')
  33. img64 = imageio.imread(img64_path)
  34. img128 = imageio.imread(img128_path)
  35. if self.transform:
  36. img64 = self.transform(img64)
  37. img128 = self.transform(img128)
  38. return img64[0:1, :, :], img128[0:1, :, :]
  39.  
  40. class ValidDataset(torch.utils.data.dataset.Dataset):
  41. def __init__(self, root_dir=os.path.join(os.getcwd(), "xray_images"), transform=None):
  42. self.root_dir = root_dir
  43. self.transform = transform
  44.  
  45. def __len__(self):
  46. return 1000
  47.  
  48. def __getitem__(self, index):
  49. img64_path = os.path.join(self.root_dir, 'train_images_64x64', f'train_{index+19001:05d}.png')
  50. img128_path = os.path.join(self.root_dir, 'train_images_128x128', f'train_{index+19001:05d}.png')
  51. img64 = imageio.imread(img64_path)
  52. img128 = imageio.imread(img128_path)
  53. if self.transform:
  54. img64 = self.transform(img64)
  55. img128 = self.transform(img128)
  56. return img64[0:1, :, :], img128[0:1, :, :]
  57.  
  58. class TestDataset(torch.utils.data.dataset.Dataset):
  59. def __init__(self, root_dir=os.path.join(os.getcwd(), "xray_images"), transform=None):
  60. self.root_dir = root_dir
  61. self.transform = transform
  62.  
  63. def __len__(self):
  64. return 3999
  65.  
  66. def __getitem__(self, index):
  67. img64_path = os.path.join(self.root_dir, 'test_images_64x64', f'test_{index+1:05d}.png')
  68. img64 = imageio.imread(img64_path)
  69. if self.transform:
  70. img64 = self.transform(img64)
  71. return img64[0:1, :, :]
  72.  
  73.  
  74. # In[3]:
  75.  
  76.  
  77. class Net(nn.Module):
  78. def __init__(self):
  79. super().__init__()
  80. self.conv1 = nn.Conv2d(1, 64, 9, 1, 4)
  81. self.conv2 = nn.Conv2d(64, 64, 5, 1, 2)
  82. self.conv3 = nn.Conv2d(64, 1, 5, 1, 2)
  83.  
  84. def forward(self, x):
  85. x = F.relu(self.conv1(x))
  86. x = F.relu(self.conv2(x))
  87. x = F.relu(self.conv3(x))
  88. return x
  89.  
  90.  
  91. # In[5]:
  92.  
  93.  
  94. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  95. # hyperparameters
  96.  
  97. num_epochs = 5
  98. batch_size = 30
  99. learning_rate = 0.001
  100. criterion = nn.MSELoss(reduction='sum').to(device)
  101.  
  102. upscale_factor = 2
  103.  
  104. def img_preprocess(data, interpolation='bicubic'):
  105. if interpolation == 'bicubic':
  106. interpolation = Image.BICUBIC
  107. elif interpolation == 'bilinear':
  108. interpolation = Image.BILINEAR
  109. elif interpolation == 'nearest':
  110. interpolation = Image.NEAREST
  111.  
  112. size = list(data.shape)
  113.  
  114. if len(size) == 4:
  115. target_height = int(size[2] * upscale_factor)
  116. target_width = int(size[3] * upscale_factor)
  117. out_data = torch.FloatTensor(size[0], size[1], target_height, target_width)
  118. for i, img in enumerate(data):
  119. transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
  120. torchvision.transforms.Resize((target_width, target_height), interpolation=interpolation),
  121. torchvision.transforms.ToTensor()])
  122.  
  123. out_data[i, :, :, :] = transform(img)
  124. return out_data
  125. else:
  126. target_height = int(size[1] * upscale_factor)
  127. target_width = int(size[2] * upscale_factor)
  128. transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
  129. torchvision.transforms.Resize((target_width, target_height), interpolation=interpolation),
  130. torchvision.transforms.ToTensor()])
  131. return transform(data)
  132.  
  133.  
  134. # In[6]:
  135.  
  136.  
  137. train_dataset = TrainDataset(root_dir='./xray_images',
  138. transform=torchvision.transforms.Compose([
  139. torchvision.transforms.ToTensor()
  140. ]))
  141. train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
  142. batch_size=batch_size,
  143. shuffle=True)
  144.  
  145. valid_dataset = ValidDataset(root_dir='./xray_images',
  146. transform=torchvision.transforms.ToTensor())
  147. valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
  148. batch_size=batch_size)
  149.  
  150. test_dataset = TestDataset(root_dir='./xray_images',
  151. transform=torchvision.transforms.ToTensor())
  152. test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
  153. batch_size=batch_size,
  154. shuffle=True)
  155.  
  156.  
  157. # In[6]:
  158.  
  159.  
  160. model = Net().to(device)
  161. optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  162.  
  163. # load the model
  164. # model.load_state_dict(torch.load('SRCNN2.pt', map_location=device))
  165.  
  166. def adjust_learning_rate(optimizer, lr):
  167. for param_group in optimizer.param_groups:
  168. param_group['lr'] = lr
  169.  
  170.  
  171. # In[7]:
  172.  
  173.  
  174. # Train the model
  175.  
  176. adjust_learning_rate(optimizer, 0.0003)
  177.  
  178. model.train()
  179. total_step = len(train_loader)
  180. for epoch in range(num_epochs):
  181. start = time.time()
  182. for i, (img64, img128) in enumerate(train_loader):
  183. img64, img128 = img_preprocess(img64).to(device), img128.to(device)
  184.  
  185. # Forward pass
  186. outputs = model(img64)
  187. loss = criterion(outputs, img128)
  188.  
  189. # Backward and optimize
  190. optimizer.zero_grad()
  191. loss.backward()
  192. optimizer.step()
  193.  
  194. # epoch end
  195. end = time.time()
  196. with torch.no_grad():
  197. valid_loss = 0
  198. for _, (valid64, valid128) in enumerate(valid_loader):
  199. valid64, valid128 = img_preprocess(valid64).to(device), valid128.to(device)
  200. valid_outputs = model(valid64)
  201. valid_loss += criterion(valid_outputs, valid128).item()
  202. print("Epoch [{}/{}], Step [{}/{}] loss: {:.2f} E[Loss]: {:.2f} Time/epoch: {:.2f} secs"
  203. .format(epoch+1, num_epochs, i+1, total_step, loss.item(), valid_loss*2, end-start))
  204.  
  205. # Save the model checkpoint
  206. torch.save(model.state_dict(), 'SRCNN2.pt')
Add Comment
Please, Sign In to add comment