Guest User

Untitled

a guest
Mar 23rd, 2017
339
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.nn.modules import Module
  5. from torch.nn.parameter import Parameter
  6. from torch.autograd import Variable
  7.  
  8. import torchvision
  9. import torchvision.transforms as transforms
  10.  
  11.  
  12. # The output of torchvision datasets are PILImage images of range [0, 1].
  13. # We transform them to Tensors of normalized range [-1, 1]
  14. transform=transforms.Compose([transforms.ToTensor(),
  15.                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  16.                              ])
  17. trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
  18. trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
  19.                                           shuffle=True, num_workers=2)
  20.  
  21. testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
  22. testloader = torch.utils.data.DataLoader(testset, batch_size=4,
  23.                                           shuffle=False, num_workers=2)
  24. classes = ('plane', 'car', 'bird', 'cat',
  25.            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  26.  
  27. class BatchReNorm1d(Module):
  28.  
  29.     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, rmax=3.0, dmax=5.0):
  30.         super(BatchReNorm1d, self).__init__()
  31.         self.num_features = num_features
  32.         self.eps = eps
  33.         self.momentum = momentum
  34.         self.affine = affine
  35.         self.rmax = rmax
  36.         self.dmax = dmax
  37.  
  38.         if self.affine:
  39.             self.weight = Parameter(torch.Tensor(num_features))
  40.             self.bias = Parameter(torch.Tensor(num_features))
  41.         else:
  42.             self.register_parameter('weight', None)
  43.             self.register_parameter('bias', None)
  44.  
  45.         self.register_buffer('running_mean', torch.zeros(num_features))
  46.         self.register_buffer('running_var', torch.ones(num_features))
  47.         self.register_buffer('r', torch.ones(num_features))
  48.         self.register_buffer('d', torch.zeros(num_features))
  49.  
  50.         self.reset_parameters()
  51.  
  52.     def reset_parameters(self):
  53.         self.running_mean.zero_()
  54.         self.running_var.fill_(1)
  55.         self.r.fill_(1)
  56.         self.d.zero_()
  57.         if self.affine:
  58.             self.weight.data.uniform_()
  59.             self.bias.data.zero_()
  60.  
  61.     def _check_input_dim(self, input):
  62.         if input.size(1) != self.running_mean.nelement():
  63.             raise ValueError('got {}-feature tensor, expected {}'
  64.                              .format(input.size(1), self.num_features))
  65.  
  66.     def forward(self, input):
  67.         self._check_input_dim(input)
  68.         n = input.size()[0]
  69.  
  70.         if self.training:
  71.             mean = torch.mean(input, dim=0)
  72.  
  73.             sum = torch.sum((input - mean.expand_as(input))**2, dim=0)
  74.             if sum == 0 and self.eps == 0:
  75.                 invstd = 0.0
  76.             else:
  77.                 invstd = 1./torch.sqrt(sum/n + self.eps)
  78.             unbiased_var = sum/(n - 1)
  79.  
  80.  
  81.             self.r = torch.clamp(torch.sqrt(unbiased_var).data / torch.sqrt(self.running_var),
  82.                             1./self.rmax, self.rmax)
  83.             self.d = torch.clamp((mean.data - self.running_mean)/ torch.sqrt(self.running_var),
  84.                             -self.dmax, self.dmax)
  85.  
  86.             r = Variable(self.r, requires_grad=False).expand_as(input)
  87.             d = Variable(self.d, requires_grad=False).expand_as(input)
  88.  
  89.             input_normalized = (input - mean.expand_as(input)) * invstd.expand_as(input)
  90.  
  91.             input_normalized = input_normalized*r + d
  92.  
  93.             self.running_mean += self.momentum * (mean.data - self.running_mean)
  94.             self.running_var += self.momentum * (unbiased_var.data - self.running_var)
  95.  
  96.             if not self.affine:
  97.                 return input_normalized
  98.  
  99.             output = input_normalized * self.weight.expand_as(input)
  100.             output += self.bias.unsqueeze(0).expand_as(input)
  101.  
  102.             return output
  103.  
  104.         else:
  105.             mean = Variable(self.running_mean).expand_as(input)
  106.             invstd  = 1./ torch.sqrt(Variable(self.running_var).expand_as(input) + self.eps)
  107.  
  108.             input_normalized = (input - mean.expand_as(input)) * invstd.expand_as(input)
  109.  
  110.             if not self.affine:
  111.                 return input_normalized
  112.  
  113.             output = input_normalized * self.weight.expand_as(input)
  114.             output += self.bias.unsqueeze(0).expand_as(input)
  115.  
  116.             return output
  117.  
  118.     def __repr__(self):
  119.         return ('{name}({num_features}, eps={eps}, momentum={momentum},'
  120.                 'affine={affine}, rmax={rmax}, dmax={dmax})'
  121.                 .format(name=self.__class__.__name__, **self.__dict__))
  122.  
  123.  
  124. class NetBatchReNorm(Module):
  125.     def __init__(self):
  126.         super(NetBatchReNorm, self).__init__()
  127.         self.conv1 = nn.Conv2d(3, 6, 5)
  128.         self.bn2d_1 = nn.BatchNorm2d(6)
  129.         self.pool  = nn.MaxPool2d(2,2)
  130.  
  131.         self.conv2 = nn.Conv2d(6, 16, 5)
  132.         self.bn2d_2   = nn.BatchNorm2d(16)
  133.  
  134.         self.fc1   = nn.Linear(16*5*5, 120)
  135.         self.bn1d_1   = BatchReNorm1d(120)
  136.  
  137.  
  138.         self.fc2   = nn.Linear(120, 84)
  139.         self.bn1d_2 = BatchReNorm1d(84, rmax=1.0, dmax=0.0)
  140.  
  141.         self.fc3   = nn.Linear(84, 10)
  142.  
  143.     def forward(self, x):
  144.         x = self.pool(self.bn2d_1(F.relu(self.conv1(x))))
  145.         x = self.pool(self.bn2d_2(F.relu(self.conv2(x))))
  146.         x = x.view(-1, 16*5*5)
  147.         x = self.bn1d_1(F.relu(self.fc1(x)))
  148.         x = self.bn1d_2(F.relu(self.fc2(x)))
  149.         x = self.fc3(x)
  150.         return x
  151.  
  152. net2 = NetBatchReNorm()
  153.  
  154. import torch.optim as optim
  155. criterion = nn.CrossEntropyLoss() # use a Classification Cross-Entropy loss
  156. optimizer = optim.SGD(net2.parameters(), lr=0.001, momentum=0.9)
  157.  
  158. loss_no_batchnorm = []
  159. for epoch in range(1): # loop over the dataset multiple times
  160.     running_loss = 0.0
  161.     for i, data in enumerate(trainloader, 0):
  162.  
  163.         # get the inputs
  164.         inputs, labels = data
  165.         # wrap them in Variable
  166.         inputs, labels = Variable(inputs), Variable(labels)
  167.         # zero the parameter gradients
  168.         optimizer.zero_grad()
  169.  
  170.         # forward + backward + optimize
  171.         outputs = net2(inputs)
  172.         loss = criterion(outputs, labels)
  173.         loss.backward()
  174.         optimizer.step()
  175.  
  176.         # print statistics
  177.         running_loss += loss.data[0]
  178.         if i % 2000 == 1999: # print every 2000 mini-batches
  179.             print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss / 2000))
  180.             loss_no_batchnorm.append(running_loss / 2000)
  181.             running_loss = 0.0
  182. print('Finished Training')
  183.  
  184. net2.eval()
  185.  
  186. correct = 0
  187. total = 0
  188. for data in testloader:
  189.     images, labels = data
  190.     outputs = net2(Variable(images))
  191.     _, predicted = torch.max(outputs.data, 1)
  192.     total += labels.size(0)
  193.     correct += (predicted == labels).sum()
  194.  
  195. print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
RAW Paste Data Copied