Guest User

Untitled

a guest
Mar 23rd, 2017
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.22 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.nn.modules import Module
  5. from torch.nn.parameter import Parameter
  6. from torch.autograd import Variable
  7.  
  8. import torchvision
  9. import torchvision.transforms as transforms
  10.  
  11. # The output of torchvision datasets are PILImage images of range [0, 1].
  12. # We transform them to Tensors of normalized range [-1, 1]
  13. transform=transforms.Compose([transforms.ToTensor(),
  14.                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  15.                              ])
  16. trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
  17. trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
  18.                                           shuffle=True, num_workers=2)
  19.  
  20. testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
  21. testloader = torch.utils.data.DataLoader(testset, batch_size=4,
  22.                                           shuffle=False, num_workers=2)
  23. classes = ('plane', 'car', 'bird', 'cat',
  24.            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  25.  
  26.  
  27. class BatchReNorm1d(Module):
  28.  
  29.     def __init__(self, num_features, eps=1e-5, momentum=0.1, rmax=3.0, dmax=5.0, affine=True):
  30.         super(BatchReNorm1d, self).__init__()
  31.         self.num_features = num_features
  32.         self.affine = affine
  33.         self.eps = eps
  34.         self.momentum = momentum
  35.         self.rmax = rmax
  36.         self.dmax = dmax
  37.         if self.affine:
  38.             self.weight = Parameter(torch.Tensor(num_features))
  39.             self.bias = Parameter(torch.Tensor(num_features))
  40.         else:
  41.             self.register_parameter('weight', None)
  42.             self.register_parameter('bias', None)
  43.         self.register_buffer('running_mean', torch.zeros(num_features))
  44.         self.register_buffer('running_var', torch.ones(num_features))
  45.         self.register_buffer('r', torch.ones(1))
  46.         self.register_buffer('d', torch.zeros(1))
  47.         self.reset_parameters()
  48.  
  49.     def reset_parameters(self):
  50.         self.running_mean.zero_()
  51.         self.running_var.fill_(1)
  52.         self.r.fill_(1)
  53.         self.d.zero_()
  54.         if self.affine:
  55.             self.weight.data.uniform_()
  56.             self.bias.data.zero_()
  57.  
  58.     def _check_input_dim(self, input):
  59.         if input.size(1) != self.running_mean.nelement():
  60.             raise ValueError('got {}-feature tensor, expected {}'
  61.                              .format(input.size(1), self.num_features))
  62.  
  63.     def _reshape_input(self, input):
  64.         return input
  65.  
  66.     def forward(self, input):
  67.         self._check_input_dim(input)
  68.         if self.training:
  69.             sample_mean = torch.mean(input, dim=0)
  70.             sample_var =  torch.var(input, dim=0)
  71.             self.r = torch.clamp(sample_var.data / self.running_var,
  72.                             1./self.rmax, self.rmax)
  73.             self.d = torch.clamp((sample_mean.data - self.running_mean)/ self.running_var,
  74.                             -self.dmax, self.dmax)
  75.             input_normalized = (input - sample_mean.expand_as(input))/sample_var.expand_as(input)
  76.             input_normalized = input_normalized*Variable(self.r).expand_as(input)
  77.             input_normalized += Variable(self.d).expand_as(input)
  78.             self.running_mean += self.momentum * (sample_mean.data - self.running_mean)
  79.             self.running_var  += self.momentum * (sample_var.data - self.running_var)
  80.             if self.affine:
  81.                 input_normalized = input_normalized * self.weight.expand_as(input)
  82.                 input_normalized += self.bias.unsqueeze(0).expand_as(input)
  83.                 return input_normalized
  84.  
  85.             else:
  86.                 return input_normalized
  87.         else:
  88.             input_normalized = (input - self.running_mean.expand_as(input))/self.running_var.expand_as(input)
  89.             if self.affine:
  90.                 input_normalized = input_normalized * self.weight.expand_as(input)
  91.                 input_normalized += self.bias.unsqueeze(0).expand_as(input)
  92.                 return input_normalized
  93.             else:
  94.                 return input_normalized
  95.  
  96.     def __repr__(self):
  97.         return ('{name}({num_features}, eps={eps}, momentum={momentum},'
  98.                 ' affine={affine})'
  99.                 .format(name=self.__class__.__name__, **self.__dict__))
  100.  
  101. class NetBatchReNorm(Module):
  102.     def __init__(self):
  103.         super(NetBatchReNorm, self).__init__()
  104.         self.conv1 = nn.Conv2d(3, 6, 5)
  105.         self.bn2d_1 = nn.BatchNorm2d(6)
  106.         self.pool  = nn.MaxPool2d(2,2)
  107.  
  108.         self.conv2 = nn.Conv2d(6, 16, 5)
  109.         self.bn2d_2   = nn.BatchNorm2d(16)
  110.  
  111.         self.fc1   = nn.Linear(16*5*5, 120)
  112.         self.bn1d_1   = BatchReNorm1d(120)
  113.  
  114.  
  115.         self.fc2   = nn.Linear(120, 84)
  116.         self.bn1d_2 = BatchReNorm1d(84)
  117.  
  118.         self.fc3   = nn.Linear(84, 10)
  119.  
  120.     def forward(self, x):
  121.         x = self.pool(self.bn2d_1(F.relu(self.conv1(x))))
  122.         x = self.pool(self.bn2d_2(F.relu(self.conv2(x))))
  123.         x = x.view(-1, 16*5*5)
  124.         x = self.bn1d_1(F.relu(self.fc1(x)))
  125.         x = self.bn1d_2(F.relu(self.fc2(x)))
  126.         x = self.fc3(x)
  127.         return x
  128.  
  129. net2 = NetBatchReNorm()
  130.  
  131.  
  132. import torch.optim as optim
  133. criterion = nn.CrossEntropyLoss() # use a Classification Cross-Entropy loss
  134. optimizer = optim.SGD(net2.parameters(), lr=0.001, momentum=0.9)
  135.  
  136.  
  137. loss_no_batchnorm = []
  138. for epoch in range(1): # loop over the dataset multiple times
  139.     running_loss = 0.0
  140.     for i, data in enumerate(trainloader, 0):
  141.         # get the inputs
  142.         inputs, labels = data
  143.         # wrap them in Variable
  144.         inputs, labels = Variable(inputs), Variable(labels)
  145.         # zero the parameter gradients
  146.         optimizer.zero_grad()
  147.         # forward + backward + optimize
  148.         outputs = net2(inputs)
  149.         loss = criterion(outputs, labels)
  150.         loss.backward()
  151.         optimizer.step()
  152.  
  153.         # print statistics
  154.         running_loss += loss.data[0]
  155.         if i % 2000 == 1999: # print every 2000 mini-batches
  156.             print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss / 2000))
  157.             loss_no_batchnorm.append(running_loss / 2000)
  158.             running_loss = 0.0
  159. print('Finished Training')
Add Comment
Please, Sign In to add comment