Not a member of Pastebin yet? Sign up — it unlocks many cool features!
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.nn.modules import Module
- from torch.nn.parameter import Parameter
- from torch.autograd import Variable
- import torchvision
- import torchvision.transforms as transforms
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1]
# (Normalize with mean=0.5, std=0.5 per channel maps [0, 1] -> [-1, 1]).
transform=transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                              ])
# CIFAR-10: 50k train / 10k test images, 32x32 RGB, 10 classes.
# download=True fetches the archive into ./data on first run.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)
# Human-readable class names, indexed by integer label id.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
class BatchReNorm1d(Module):
    """Batch Renormalization for 2D (batch, num_features) inputs.

    Implements Ioffe, "Batch Renormalization" (arXiv:1702.03275). During
    training the activations are normalized with the minibatch statistics
    and then corrected by clamped constants r and d so that the output
    matches what the running statistics would produce; at inference the
    running statistics are used directly.

    Args:
        num_features: size of the feature dimension (dim 1).
        eps: small constant added to variances for numerical stability.
        momentum: update rate for the running mean/variance.
        rmax: clamp bound for the std-ratio correction r (in [1/rmax, rmax]).
        dmax: clamp bound for the mean-shift correction d (in [-dmax, dmax]).
        affine: if True, apply a learnable per-feature scale and shift.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, rmax=3.0, dmax=5.0, affine=True):
        super(BatchReNorm1d, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.rmax = rmax
        self.dmax = dmax
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        # Buffers: persisted with state_dict but not trained by the optimizer.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        # Last computed renorm corrections (kept for inspection/debugging).
        self.register_buffer('r', torch.ones(1))
        self.register_buffer('d', torch.zeros(1))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset running statistics, corrections, and affine parameters."""
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.r.fill_(1)
        self.d.zero_()
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def _check_input_dim(self, input):
        # Validate that dim 1 matches the configured feature count.
        if input.size(1) != self.running_mean.nelement():
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))

    def _reshape_input(self, input):
        # 1d case: input is already (batch, features); no reshape needed.
        return input

    def forward(self, input):
        """Normalize `input` of shape (batch, num_features)."""
        self._check_input_dim(input)
        if self.training:
            sample_mean = torch.mean(input, dim=0)
            sample_var = torch.var(input, dim=0)
            # BUG FIX: the original divided by the raw *variance* and never
            # used eps. Per the paper, normalization uses the standard
            # deviation sqrt(var + eps), and r/d are ratios of stds.
            sample_std = torch.sqrt(sample_var + self.eps)
            running_std = torch.sqrt(self.running_var + self.eps)
            # r and d are treated as constants (no gradient), hence detach().
            r = torch.clamp(sample_std.detach() / running_std,
                            1. / self.rmax, self.rmax)
            d = torch.clamp((sample_mean.detach() - self.running_mean) / running_std,
                            -self.dmax, self.dmax)
            self.r = r
            self.d = d
            output = (input - sample_mean.expand_as(input)) / sample_std.expand_as(input)
            output = output * r.expand_as(input) + d.expand_as(input)
            # Exponential-moving-average update of the running statistics.
            self.running_mean += self.momentum * (sample_mean.detach() - self.running_mean)
            self.running_var += self.momentum * (sample_var.detach() - self.running_var)
        else:
            # Inference: normalize with the running statistics (std, not var).
            running_std = torch.sqrt(self.running_var + self.eps)
            output = (input - self.running_mean.expand_as(input)) / running_std.expand_as(input)
        if self.affine:
            output = output * self.weight.expand_as(input)
            output = output + self.bias.unsqueeze(0).expand_as(input)
        return output

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
                ' affine={affine})'
                .format(name=self.__class__.__name__, **self.__dict__))
class NetBatchReNorm(Module):
    """LeNet-style CIFAR-10 classifier.

    Uses BatchNorm2d after each conv layer and BatchReNorm1d after each
    hidden fully-connected layer; outputs raw 10-class logits.
    """

    def __init__(self):
        super(NetBatchReNorm, self).__init__()
        # Feature extractor: 3x32x32 -> conv5 -> 6x28x28 -> pool -> 6x14x14
        #                    -> conv5 -> 16x10x10 -> pool -> 16x5x5.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn2d_1 = nn.BatchNorm2d(6)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2d_2 = nn.BatchNorm2d(16)
        # Classifier head on the flattened 16*5*5 = 400 features.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn1d_1 = BatchReNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn1d_2 = BatchReNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) image tensor to (batch, 10) logits."""
        feat = self.pool(self.bn2d_1(F.relu(self.conv1(x))))
        feat = self.pool(self.bn2d_2(F.relu(self.conv2(feat))))
        flat = feat.view(-1, 16 * 5 * 5)
        hidden = self.bn1d_1(F.relu(self.fc1(flat)))
        hidden = self.bn1d_2(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)
net2 = NetBatchReNorm()
import torch.optim as optim

criterion = nn.CrossEntropyLoss()  # use a Classification Cross-Entropy loss
optimizer = optim.SGD(net2.parameters(), lr=0.001, momentum=0.9)

# Average training loss per 2000-batch window, kept for later comparison.
# NOTE(review): despite the name, this run DOES use (re)normalization layers.
loss_no_batchnorm = []
for epoch in range(1):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        # wrap them in Variable (legacy pre-0.4 API; Variable is now a no-op alias)
        inputs, labels = Variable(inputs), Variable(labels)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net2(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        # BUG FIX: loss.data[0] raises on PyTorch >= 0.5 (indexing a 0-dim
        # tensor); loss.item() is the supported way to read a scalar loss.
        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            loss_no_batchnorm.append(running_loss / 2000)
            running_loss = 0.0
print('Finished Training')
Add a comment
Please sign in to add a comment.