Advertisement
Guest User

Untitled

a guest
May 27th, 2019
143
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.72 KB | None | 0 0
  1. import torchvision
  2. import torchvision.transforms as transforms
  3. import torch
  4. from torch import nn
  5. import torch.optim as optim
  6. import pickle as pkl
  7. import numpy as np
  8.  
  9.  
  10. def shrink(X, old_size, quotient):
  11.     X_s = np.stack([np.reshape(x, (old_size, old_size)).reshape(old_size // quotient, quotient, -1, 2)
  12.                    .swapaxes(1, quotient)
  13.                    .reshape(-1, quotient, quotient) for x in X])
  14.     X_s = np.sum(X_s, axis=2)
  15.     X_s = np.sum(X_s, axis=2)
  16.     return X_s / (quotient * 2)
  17.  
  18.  
  19. if __name__ == '__main__':
  20.     # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  21.  
  22.     data = pkl.load(open('train.pkl', mode='rb'))
  23.     data[0] = shrink(data[0], 36, 2).astype(np.float32)
  24.  
  25.     traindata = data[0][:50000]
  26.  
  27.     testdata = data[0][50000:]
  28.     train_labels = data[1][:50000]
  29.     test_labels = data[1][50000:]
  30.     tensor_traindata = torch.from_numpy(traindata)
  31.     tensor_train_labels = torch.from_numpy(train_labels)
  32.  
  33.     tensor_testdata = torch.from_numpy(testdata)
  34.     tensot_test_labels = torch.from_numpy(test_labels)
  35.  
  36.     train_dataset = torch.utils.data.TensorDataset(tensor_traindata, tensor_train_labels)
  37.     test_dataset = torch.utils.data.TensorDataset(tensor_testdata, tensot_test_labels)
  38.  
  39.     trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=16,
  40.                                               shuffle=True, num_workers=2)
  41.     testloader = torch.utils.data.DataLoader(test_dataset, batch_size=16,
  42.                                              shuffle=False, num_workers=2)
  43.  
  44.     torch.multiprocessing.freeze_support()
  45.  
  46.  
  47.     class SimpleNetwork(nn.Module):
  48.         def __init__(self):
  49.             super(SimpleNetwork, self).__init__()
  50.             self.fc1 = nn.Linear(324, 324)
  51.             self.fc2 = nn.Linear(324, 10)
  52.             self.sig = nn.Sigmoid()
  53.             self.params = list(self.fc1.parameters())
  54.  
  55.         def forward(self, x):
  56.             x = self.fc1(x)
  57.             x = self.sig(x)
  58.             x = self.fc2(x)
  59.             return x
  60.  
  61.         def get_params(self):
  62.             w = self.params
  63.             return w
  64.  
  65.  
  66.     net = SimpleNetwork()
  67.     # net.to(device)
  68.     criterion = nn.CrossEntropyLoss()
  69.     optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  70.  
  71.     print('testing started')
  72.     # net.load_state_dict(torch.load('parameters.pkl'))
  73.  
  74.     for epoch in range(20):
  75.         net.eval()
  76.         running_loss = 0.0
  77.  
  78.         for i, data in enumerate(trainloader, 0):
  79.             inputs, labels = data
  80.             # inputs, labels = inputs.to(device), labels.to(device)
  81.  
  82.             optimizer.zero_grad()
  83.  
  84.             outputs = net(inputs)
  85.             loss = criterion(outputs, labels)
  86.             loss.backward()
  87.             optimizer.step()
  88.             # print(torch.sum(net.fc1.weight.data))
  89.  
  90.             running_loss += loss.item()
  91.             if i % 2000 == 1999:
  92.                 print('[%d, %5d] loss: %.3f' %
  93.                       (epoch + 1, i + 1, running_loss / 2000))
  94.                 running_loss = 0.0
  95.  
  96.     data = net.fc1.weight.data.numpy(), net.fc2.weight.data.numpy(), net.fc1.bias.data.numpy(), net.fc2.bias.data.numpy()
  97.     with open('weights.pkl', 'wb') as f:
  98.         pkl.dump(data, f)
  99.  
  100.     torch.save(net.state_dict(), 'parameters.pkl')
  101.     print('finished testing')
  102.  
  103.     correct = 0
  104.     total = 0
  105.     with torch.no_grad():
  106.         for data in testloader:
  107.             images, labels = data
  108.             outputs = net(images)
  109.             _, predicted = torch.max(outputs.data, 1)
  110.             total += labels.size(0)
  111.             correct += (predicted == labels).sum().item()
  112.  
  113.     print('Accuracy of the network on the 10000 test images: %d %%' % (
  114.             100 * correct / total))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement