Advertisement
Guest User

Untitled

a guest
Jan 29th, 2020
355
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
Python 2.97 KB | None | 0 0
  1. from __future__ import print_function, division
  2. import os
  3. from torch import nn, optim, save, load
  4. import pandas as pd
  5. from skimage import io, transform
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. from torch.utils.data import DataLoader
  9. from torchvision import transforms, utils
  10. import torch.nn.functional as F
  11. from torch.autograd import Variable
  12.  
  13. from dataset import DenoisedImageDataset
  14.  
  15. BATCH_SIZE = 20
  16. INPUT_IMG_SIZE = 128 * 128 * 3
  17. OUTPUT_IMG_SIZE = 128 * 128 * 3
  18.  
  19. data = DenoisedImageDataset()
  20. data_loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
  21.  
  22.  
  23. class Denoiser(nn.Module):
  24.     def __init__(self):
  25.         super(Denoiser, self).__init__()
  26.  
  27.         self.hidden0 = nn.Linear(INPUT_IMG_SIZE, 4096)
  28.  
  29.         self.hidden1 = nn.Linear(4096, 1024)
  30.         self.hidden2 = nn.Linear(1024, 512)
  31.         self.hidden3 = nn.Linear(512, 256)
  32.         self.hidden4 = nn.Linear(256, 128)
  33.         self.hidden5 = nn.Linear(128, 64)
  34.         self.hidden6 = nn.Linear(64, 256)
  35.         self.hidden7 = nn.Linear(256, 512)
  36.         self.hidden8 = nn.Linear(512, 1024)
  37.  
  38.         self.out = nn.Linear(1024, OUTPUT_IMG_SIZE)
  39.  
  40.     def forward(self, x):
  41.         x = x.float()
  42.         batch = x.size()[0]
  43.         x = x.view(batch, INPUT_IMG_SIZE)
  44.         x = F.relu(self.hidden0(x))
  45.         x = F.relu(self.hidden1(x))
  46.         x = F.relu(self.hidden2(x))
  47.         x = F.relu(self.hidden3(x))
  48.         x = F.relu(self.hidden4(x))
  49.         x = F.relu(self.hidden5(x))
  50.         x = F.relu(self.hidden6(x))
  51.         x = F.relu(self.hidden7(x))
  52.         x = F.relu(self.hidden8(x))
  53.  
  54.         x = F.sigmoid(self.out(x))
  55.  
  56.         x = x.view(batch, 3, 128, 128)
  57.  
  58.         # x = F.relu(self.t_conv1(x))
  59.         # x = F.relu(self.t_conv2(x))
  60.         # x = F.relu(self.t_conv3(x))
  61.         # x = F.tanh(self.conv_out(x))
  62.  
  63.         return x
  64.  
  65.  
  66. loss = nn.MSELoss()
  67. denoiser = Denoiser()
  68. optimizer = optim.Adam(denoiser.parameters(), lr=0.002)
  69.  
  70. update_lr = lambda epoch: 0.9 ** epoch
  71.  
  72. lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[update_lr])
  73.  
  74.  
  75. def train_denoiser(noised_image, denoised_image, model, save_image=False, index=0):
  76.     prediction = model(noised_image)
  77.     optimizer.zero_grad()
  78.     error = loss(prediction, denoised_image)
  79.     error.backward()
  80.     optimizer.step()
  81.     if save_image:
  82.         print(error)
  83.         utils.save_image(prediction[0], 'output/' + str(index) + '.jpg')
  84.  
  85.  
  86. print('Begin training model on sample of size', len(data_loader))
  87. print('Treated :', end='')
  88. for epoch in range(1000):
  89.     for n_batch, data in enumerate(data_loader):
  90.         noised_image = data['input']
  91.         denoised_image = data['output']
  92.         batch_size = len(data['input'])
  93.         train_denoiser(noised_image, denoised_image, denoiser, True, str(epoch) + '.' + str(n_batch))
  94.         if n_batch % 5 == 0:
  95.             print(epoch, n_batch, optimizer.param_groups[0]['lr'], end=', ')
  96.     lr_scheduler.step()
  97.  
  98. save(denoiser.state_dict(), 'saved_model')
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement