Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from __future__ import print_function, division
- import os
- from torch import nn, optim, save, load
- import pandas as pd
- from skimage import io, transform
- import numpy as np
- import matplotlib.pyplot as plt
- from torch.utils.data import DataLoader
- from torchvision import transforms, utils
- import torch.nn.functional as F
- from torch.autograd import Variable
- from dataset import DenoisedImageDataset
# Training batch size and per-sample flattened sizes. 3 * 128 * 128 matches the
# (3, 128, 128) shape the Denoiser model reshapes its output to.
BATCH_SIZE = 20
INPUT_IMG_SIZE = 3 * 128 * 128
OUTPUT_IMG_SIZE = 3 * 128 * 128

# Project dataset, served in shuffled mini-batches of BATCH_SIZE samples.
data = DenoisedImageDataset()
data_loader = DataLoader(dataset=data, batch_size=BATCH_SIZE, shuffle=True)
class Denoiser(nn.Module):
    """Fully connected denoising autoencoder for fixed-size images.

    ``forward`` flattens each sample and runs it through ReLU-activated linear
    layers (input -> 4096 -> 1024 -> 512 -> 256 -> 128 -> 64 -> 256 -> 512 ->
    1024). A final sigmoid-activated output layer follows, and its result is
    reshaped to ``(batch, *output_shape)``, so every output value lies in
    [0, 1].

    The layer attribute names (``hidden0`` ... ``hidden8``, ``out``) are kept
    stable because they are the keys of ``state_dict()``. Renaming them would
    make previously saved checkpoints unloadable.

    Args:
        input_size: Number of values per flattened input sample. ``None``
            (the default) uses the module-level ``INPUT_IMG_SIZE``.
        output_shape: Per-sample output shape, e.g. ``(channels, height,
            width)``. The output layer's width is the product of these
            dimensions, so the layer width and the final reshape can never
            disagree. Defaults to ``(3, 128, 128)``.
    """

    def __init__(self, input_size=None, output_shape=(3, 128, 128)):
        super(Denoiser, self).__init__()
        if input_size is None:
            input_size = INPUT_IMG_SIZE
        self.input_size = input_size
        self.output_shape = tuple(output_shape)
        output_size = 1
        for dim in self.output_shape:
            output_size *= dim

        # Encoder: narrow down to a 64-unit bottleneck.
        self.hidden0 = nn.Linear(input_size, 4096)
        self.hidden1 = nn.Linear(4096, 1024)
        self.hidden2 = nn.Linear(1024, 512)
        self.hidden3 = nn.Linear(512, 256)
        self.hidden4 = nn.Linear(256, 128)
        self.hidden5 = nn.Linear(128, 64)
        # Decoder: widen back out before the output projection.
        self.hidden6 = nn.Linear(64, 256)
        self.hidden7 = nn.Linear(256, 512)
        self.hidden8 = nn.Linear(512, 1024)
        self.out = nn.Linear(1024, output_size)

    def forward(self, x):
        """Denoise a batch of images.

        Args:
            x: Tensor whose first dimension is the batch. The remaining
                dimensions must hold exactly ``input_size`` values per sample.
                It is cast to float32 first.

        Returns:
            Float tensor of shape ``(batch, *output_shape)`` with values in
            [0, 1].
        """
        # NOTE(review): the input is flattened in whatever memory order it
        # arrives, while the output is laid out as ``output_shape`` (C, H, W by
        # default). Confirm the dataset yields inputs in the same layout.
        x = x.float()
        batch = x.size(0)
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(batch, self.input_size)
        hidden_layers = (
            self.hidden0, self.hidden1, self.hidden2,
            self.hidden3, self.hidden4, self.hidden5,
            self.hidden6, self.hidden7, self.hidden8,
        )
        for layer in hidden_layers:
            x = F.relu(layer(x))
        # Tensor.sigmoid is the canonical form of the deprecated F.sigmoid.
        x = self.out(x).sigmoid()
        return x.view(batch, *self.output_shape)
# Mean-squared-error criterion, called as loss(prediction, target) during training.
loss = nn.MSELoss()

# The model instance that the training step and checkpointing operate on.
denoiser = Denoiser()

# Adam with a base learning rate of 0.002; the scheduler rescales it.
optimizer = optim.Adam(params=denoiser.parameters(), lr=0.002)


def update_lr(epoch):
    """Return the LambdaLR multiplier for ``epoch``: 0.9 ** epoch.

    NOTE(review): if the scheduler is stepped once per epoch, this factor drops
    below 1e-4 after 88 steps, so most of a 1000-epoch run trains at an almost
    zero learning rate. Confirm this decay is intended.
    """
    return 0.9 ** epoch


# Passing the callable directly is equivalent to a one-element list for the
# single parameter group that Adam creates here.
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=update_lr)
def train_denoiser(noised_image, denoised_image, model, save_image=False,
                   index=0, opt=None, criterion=None):
    """Run one optimisation step of ``model`` on a single batch.

    Args:
        noised_image: Batch fed to ``model``.
        denoised_image: Target batch compared against the prediction. It is
            cast to the prediction's dtype so the loss receives matching dtypes.
        model: Module being trained.
        save_image: When true, print the loss and write the first prediction
            of the batch to ``'output/<index>.jpg'``.
        index: Label used in the saved image's filename.
        opt: Optimizer to step. ``None`` (the default) uses the module-level
            ``optimizer``. Pass one explicitly whenever ``model`` is not the
            module-level ``denoiser``, otherwise its parameters never update.
        criterion: Loss function called as ``criterion(prediction, target)``.
            ``None`` (the default) uses the module-level ``loss``.

    Returns:
        The batch loss as a Python float.
    """
    if opt is None:
        opt = optimizer
    if criterion is None:
        criterion = loss
    prediction = model(noised_image)
    opt.zero_grad()
    # NOTE(review): the cast fixes dtype mismatches (e.g. float64 targets) but
    # not value range. Confirm targets are already scaled like the model output.
    error = criterion(prediction, denoised_image.to(prediction.dtype))
    error.backward()
    opt.step()
    value = error.item()
    if save_image:
        print(value)
        # detach(): torchvision versions that convert to NumPy fail on tensors
        # that still require grad.
        utils.save_image(prediction[0].detach(), 'output/' + str(index) + '.jpg')
    return value
# Report the number of samples; len(data_loader) alone would be the number of
# batches.
print('Begin training model on sample of size', len(data_loader.dataset))
print('Treated :', end='')
NUM_EPOCHS = 1000
for epoch in range(NUM_EPOCHS):
    # The loop variable is `batch`, not `data`, so the module-level dataset
    # bound to `data` is not overwritten by the last batch.
    for n_batch, batch in enumerate(data_loader):
        noised_image = batch['input']
        denoised_image = batch['output']
        train_denoiser(noised_image, denoised_image, denoiser,
                       save_image=True, index='{}.{}'.format(epoch, n_batch))
        if n_batch % 5 == 0:
            print(epoch, n_batch, optimizer.param_groups[0]['lr'], end=', ')
    # NOTE(review): the source paste lost its indentation. The scheduler step
    # and the checkpoint are placed once per epoch, matching the `epoch`
    # argument of the LR lambda. Confirm against the author's intent.
    lr_scheduler.step()
    # Overwrites the same file each epoch, so it always holds the latest weights.
    save(denoiser.state_dict(), 'saved_model')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement