Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torchvision
- import torchvision.transforms as transforms
- import torch
- from torch import nn
- import torch.optim as optim
- import pickle as pkl
- import numpy as np
def shrink(X, old_size, quotient):
    """Downsample each flattened square image in X by block average pooling.

    Each row of X is a flattened (old_size, old_size) image.  Non-overlapping
    (quotient, quotient) blocks are averaged and the result is returned with
    each image flattened to (old_size // quotient) ** 2 values.

    Parameters
    ----------
    X : array-like, shape (n_images, old_size * old_size)
        Batch of flattened square images.
    old_size : int
        Side length of the square images.
    quotient : int
        Pooling factor; must evenly divide old_size.

    Returns
    -------
    numpy.ndarray, shape (n_images, (old_size // quotient) ** 2)

    Notes
    -----
    The original implementation only worked for quotient == 2: it hard-coded
    a trailing reshape dimension of 2, used ``swapaxes(1, quotient)`` (correct
    only when quotient == 2), and divided by ``quotient * 2`` instead of
    ``quotient ** 2``.  This version is equivalent for quotient == 2 and
    correct for any divisor, and avoids the per-image Python loop.
    """
    X = np.asarray(X)
    new_size = old_size // quotient
    # Split each spatial axis into (block index, within-block offset):
    # (n, new_size, quotient, new_size, quotient).
    blocks = X.reshape(-1, new_size, quotient, new_size, quotient)
    # Average over the two within-block axes, i.e. sum / quotient**2.
    pooled = blocks.mean(axis=(2, 4))
    return pooled.reshape(-1, new_size * new_size)
if __name__ == '__main__':
    # freeze_support() must run first inside the __main__ guard so Windows
    # frozen executables can spawn the DataLoader worker processes; it was
    # previously called after the loaders were built.  No-op elsewhere.
    torch.multiprocessing.freeze_support()

    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # NOTE(review): pickle.load executes arbitrary code from the file --
    # only load train.pkl from a trusted source.
    # (was an unclosed open(); the with-block guarantees the handle is released)
    with open('train.pkl', mode='rb') as f:
        data = pkl.load(f)

    # 36x36 images average-pooled 2x down to 18x18 = 324 features per image.
    data[0] = shrink(data[0], 36, 2).astype(np.float32)

    traindata = data[0][:50000]
    testdata = data[0][50000:]
    train_labels = data[1][:50000]
    test_labels = data[1][50000:]

    tensor_traindata = torch.from_numpy(traindata)
    tensor_train_labels = torch.from_numpy(train_labels)
    tensor_testdata = torch.from_numpy(testdata)
    tensor_test_labels = torch.from_numpy(test_labels)  # was misspelled 'tensot_test_labels'

    train_dataset = torch.utils.data.TensorDataset(tensor_traindata, tensor_train_labels)
    test_dataset = torch.utils.data.TensorDataset(tensor_testdata, tensor_test_labels)
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=16,
                                              shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(test_dataset, batch_size=16,
                                             shuffle=False, num_workers=2)

    class SimpleNetwork(nn.Module):
        """Two-layer MLP: 324 -> 324 with sigmoid, then 324 -> 10 logits."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(324, 324)
            self.fc2 = nn.Linear(324, 10)
            self.sig = nn.Sigmoid()
            # Kept for interface compatibility with the original: a snapshot
            # of fc1's (only fc1's) parameter objects, taken at init time.
            self.params = list(self.fc1.parameters())

        def forward(self, x):
            x = self.fc1(x)
            x = self.sig(x)
            x = self.fc2(x)
            return x

        def get_params(self):
            # Returns the fc1 parameter snapshot captured in __init__.
            return self.params

    net = SimpleNetwork()
    # net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    print('training started')  # was 'testing started' -- this loop trains
    # net.load_state_dict(torch.load('parameters.pkl'))

    # was net.eval() inside the epoch loop: the optimisation loop must run
    # in training mode (matters if dropout/batchnorm are ever added).
    net.train()
    for epoch in range(20):
        running_loss = 0.0
        # 'batch' was previously named 'data', shadowing the dataset above.
        for i, batch in enumerate(trainloader, 0):
            inputs, labels = batch
            # inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print(torch.sum(net.fc1.weight.data))
            running_loss += loss.item()
            if i % 2000 == 1999:  # log the running mean loss every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

    # Persist the raw weight/bias arrays (same 4-tuple layout as before)
    # plus the full state dict.
    weights = (net.fc1.weight.data.numpy(), net.fc2.weight.data.numpy(),
               net.fc1.bias.data.numpy(), net.fc2.bias.data.numpy())
    with open('weights.pkl', 'wb') as f:
        pkl.dump(weights, f)
    torch.save(net.state_dict(), 'parameters.pkl')
    print('finished training')  # was 'finished testing'

    correct = 0
    total = 0
    net.eval()  # evaluation mode for the held-out pass
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement