Guest User

Untitled

a guest
Jun 1st, 2019
180
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.90 KB | None | 0 0
  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4.  
  5. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  6.  
  7. trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
  8. download=True, transform=transform)
  9. trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
  10. shuffle=True, num_workers=2)
  11.  
  12. testset = torchvision.datasets.CIFAR10(root='./data', train=False,
  13. download=True, transform=transform)
  14. testloader = torch.utils.data.DataLoader(testset, batch_size=4,
  15. shuffle=False, num_workers=2)
  16.  
  17. classes = ('plane', 'car', 'bird', 'cat',
  18. 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  19.  
  20. import torch.nn as nn
  21. import torch.nn.functional as F
  22.  
  23. def make_sparse(in_dim, out_dim, size, mode):
  24. # Always compress the more compressible dimension
  25. invert = False
  26. if out_dim > in_dim:
  27. invert = True
  28. out_dim, in_dim = in_dim, out_dim
  29.  
  30. mask = torch.zeros(out_dim, in_dim)
  31. assert(mode in ['Expander', 'Group'])
  32.  
  33. if mode == 'Expander':
  34. minstd = -1
  35. for _ in range(20):
  36. m = torch.zeros(out_dim, in_dim)
  37. for i in range(out_dim):
  38. x = torch.randperm(in_dim)
  39. m[i][x[:size]] = 1
  40. std = m.sum(dim=0).std()
  41. if minstd == -1 or minstd > std:
  42. minstd = std
  43. mask = m
  44.  
  45. elif mode == 'Group':
  46. assert(in_dim%size==0 and out_dim%size==0)
  47. for i in range(out_dim):
  48. for j in range(in_dim):
  49. if (i//size == j//size):
  50. mask[i][j] == 1
  51.  
  52. if invert:
  53. mask = mask.t()
  54.  
  55. return mask
  56.  
  57. class SparseConv2d(torch.nn.Module):
  58. def __init__(self, inWCin, inWCout, kernel_size, stride=1, padding=0, dilation=1, sparse_size=0, sparse_mode='Expander'):
  59. super(SparseConv2d, self).__init__()
  60. self.kernel_size = kernel_size
  61. self.stride = stride
  62. self.padding = padding
  63. self.dilation = dilation
  64. self.out_channels = inWCout
  65.  
  66. mask = make_sparse(in_dim=inWCin, out_dim=inWCout, size=sparse_size, mode=sparse_mode)
  67. weight = torch.zeros((inWCout, inWCin))
  68. weight = 0.01*torch.nn.init.kaiming_normal_(weight)
  69. weight = torch.mul(weight, mask)
  70. weight = weight.unsqueeze(2).unsqueeze(3).repeat(1, 1, kernel_size, kernel_size)
  71. weight = weight.view(weight.size(0), -1)
  72. weight = weight.to_sparse().cuda()
  73. self.weight = torch.nn.Parameter(weight, requires_grad=True)
  74.  
  75. def forward(self, x):
  76. out = (x.size(2)+2*self.padding-self.dilation*(self.kernel_size-1)-1)//self.stride+1
  77. x_unf = torch.nn.functional.unfold(x, (self.kernel_size, self.kernel_size)).transpose(1,2)
  78. x_unf = torch.sparse.mm(self.weight, x_unf.reshape(-1, x_unf.size(2)).t()).t().reshape(x.size(0),-1,self.out_channels).transpose(1,2)
  79. x_unf = x_unf.view(x_unf.size(0), x_unf.size(1), out, out)
  80. return x_unf
  81.  
  82.  
  83. class Net(nn.Module):
  84. def __init__(self):
  85. super(Net, self).__init__()
  86. self.conv1 = nn.Conv2d(3, 6, 5, bias=False)
  87. self.pool = nn.MaxPool2d(2, 2)
  88. self.conv2 = nn.Conv2d(6, 16, 5, bias=False)
  89. self.fc1 = nn.Conv2d(16, 120, 5, bias=False)
  90. self.fc2 = nn.Conv2d(120, 84, 1, bias=False)
  91. self.fc3 = nn.Conv2d(84, 10, 1, bias=False)
  92.  
  93. def forward(self, x):
  94. x = self.pool(F.relu(self.conv1(x)))
  95. x = self.pool(F.relu(self.conv2(x)))
  96. x = F.relu(self.fc1(x))
  97. x = F.relu(self.fc2(x))
  98. x = self.fc3(x)
  99. x = x.view(-1, 10)
  100. return x
  101.  
  102. class Net2(nn.Module):
  103. def __init__(self):
  104. super(Net2, self).__init__()
  105. self.conv1 = nn.Conv2d(3, 6, 5, bias=False)
  106. self.pool = nn.MaxPool2d(2, 2)
  107. self.conv2 = SparseConv2d(6, 16, 5, sparse_size=12)
  108. self.fc1 = SparseConv2d(16, 120, 5, sparse_size=80)
  109. self.fc2 = SparseConv2d(120, 84, 1, sparse_size=64)
  110. self.fc3 = nn.Conv2d(84, 10, 1, bias=False)
  111.  
  112. def forward(self, x):
  113. x = self.pool(F.relu(self.conv1(x)))
  114. x = self.pool(F.relu(self.conv2(x)))
  115. x = F.relu(self.fc1(x))
  116. x = F.relu(self.fc2(x))
  117. x = self.fc3(x)
  118. x = x.view(-1, 10)
  119. return x
  120.  
  121. #net = Net().cuda()
  122. net = Net2().cuda()
  123.  
  124. import torch.optim as optim
  125.  
  126. criterion = nn.CrossEntropyLoss().cuda()
  127. optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  128.  
  129. for epoch in range(2): # loop over the dataset multiple times
  130.  
  131. running_loss = 0.0
  132. for i, data in enumerate(trainloader, 0):
  133. # get the inputs; data is a list of [inputs, labels]
  134. inputs, labels = data
  135. inputs, labels = inputs.cuda(), labels.cuda()
  136.  
  137. # zero the parameter gradients
  138. optimizer.zero_grad()
  139.  
  140. # forward + backward + optimize
  141. outputs = net(inputs)
  142. loss = criterion(outputs, labels)
  143. loss.backward()
  144. optimizer.step()
  145.  
  146. # print statistics
  147. running_loss += loss.item()
  148. if i % 2000 == 1999: # print every 2000 mini-batches
  149. print('[%d, %5d] loss: %.3f' %
  150. (epoch + 1, i + 1, running_loss / 2000))
  151. running_loss = 0.0
  152.  
  153. print('Finished Training')
  154.  
  155. correct = 0
  156. total = 0
  157. with torch.no_grad():
  158. for data in testloader:
  159. images, labels = data
  160. images, labels = images.cuda(), labels.cuda()
  161. outputs = net(images)
  162. _, predicted = torch.max(outputs.data, 1)
  163. total += labels.size(0)
  164. correct += (predicted == labels).sum().item()
  165.  
  166. print('Accuracy of the network on the 10000 test images: %d %%' % (
  167. 100 * correct / total))
Advertisement
Add Comment
Please, Sign In to add comment