# Small CNN classifier for 12 tree species, trained on image folders (train/, cv/, test/).
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Two strided 5x5 convolutions, each followed by a 3x3 max-pool with stride 2.
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5, stride=2, padding=2)
        self.pool = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=2, padding=2)
        # 16 channels * 7 * 7 spatial positions = 784 features for 128x128 inputs.
        self.fc1 = nn.Linear(784, 240)
        self.fc2 = nn.Linear(240, 84)
        self.fc3 = nn.Linear(84, 12)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, self.num_flat_features(x))
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


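# Sanity-check sketch (not part of the training flow): with the 128x128 inputs
# produced by the Resize transform below, conv1 (5x5, stride 2, pad 2) gives
# 64x64 feature maps, the 3x3/stride-2 pool gives 31x31, conv2 gives 16x16 and
# the second pool 7x7, so the flattened size is 16 * 7 * 7 = 784, matching fc1.
# Uncomment to verify with a dummy batch:
# _net = Net()
# with torch.no_grad():
#     _x = _net.pool(F.relu(_net.conv1(torch.zeros(1, 3, 128, 128))))
#     _x = _net.pool(F.relu(_net.conv2(_x)))
#     print(_x.shape)  # expected: torch.Size([1, 16, 7, 7])

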
class Memory(Dataset):
    """Keeps an already-decoded split in RAM as NumPy arrays, so later epochs
    avoid re-reading and re-decoding the image files."""
    def __init__(self, dataset_array, dataset_labels):
        self.labels = dataset_labels.astype(np.float64)
        self.images = dataset_array.astype(np.float64)

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


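# Minimal usage sketch for Memory (the array names are illustrative only):
# assuming `imgs` is an (N, 3, 128, 128) image array and `lbls` an (N,) label
# array, a shuffled in-memory loader would look like:
# cached = Memory(imgs, lbls)
# cached_loader = torch.utils.data.DataLoader(cached, batch_size=32, shuffle=True)

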
def dataloader(batch_size, num_workers=0):
#    transform0 = transforms.Compose(
#    [transforms.Resize((128, 128)),
#     transforms.ToTensor()
#     ])

#    datasetmeanstd = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform0)
#    mean, std = calculate_img_stats_full(datasetmeanstd)
#    mean = mean.numpy()
#    std = std.numpy()
    # Per-channel mean and std precomputed with calculate_img_stats_full (commented out above).
    mean, std = (0.56337297, 0.5472399, 0.5224609), (0.13041796, 0.13301197, 0.139214)

    train_transform = transforms.Compose(
        [transforms.Resize((128, 128)),
         transforms.ToTensor(),
         transforms.Normalize(mean, std)
         ])

    test_transform = transforms.Compose(
        [transforms.Resize((128, 128)),
         transforms.ToTensor(),
         transforms.Normalize(mean, std)
         ])

    print("Data loading...")
    # Load each split in a single full-size batch so it can be cached in RAM below.
    trainset = torchvision.datasets.ImageFolder(root='train', transform=train_transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=len(trainset), shuffle=False, num_workers=num_workers, pin_memory=False)
    cvset = torchvision.datasets.ImageFolder(root='cv', transform=test_transform)
    cvloader = torch.utils.data.DataLoader(cvset, batch_size=len(cvset), shuffle=False, num_workers=num_workers, pin_memory=False)
    testset = torchvision.datasets.ImageFolder(root='test', transform=test_transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=len(testset), shuffle=False, num_workers=num_workers, pin_memory=False)
    testloader2 = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=num_workers, pin_memory=False)

    # Decode each split once and wrap the cached arrays in Memory datasets.
    trainset_array, trainset_labels = next(iter(trainloader))
    trainset_m = Memory(trainset_array.numpy(), trainset_labels.numpy())
    trainloader_m = torch.utils.data.DataLoader(trainset_m, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=False)

    cvset_array, cvset_labels = next(iter(cvloader))
    cvset_m = Memory(cvset_array.numpy(), cvset_labels.numpy())
    cvloader_m = torch.utils.data.DataLoader(cvset_m, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=False)

    testset_array, testset_labels = next(iter(testloader))
    testset_m = Memory(testset_array.numpy(), testset_labels.numpy())
    testloader_m = torch.utils.data.DataLoader(testset_m, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=False)
    print("Data loaded.")
    return trainloader_m, cvloader_m, testloader_m, testloader2


def calculate_img_stats_full(dataset):
    # Per-channel mean/std; stacks every image, so the whole dataset must fit in memory.
    imgs_ = torch.stack([img for img, _ in dataset], dim=3)
    imgs_ = imgs_.view(3, -1)
    imgs_mean = imgs_.mean(dim=1)
    imgs_std = imgs_.std(dim=1)
    return imgs_mean, imgs_std


def training(trainloader, cvloader, testloader, net, opt, epochs, lr, weight_decay, step_size, gamma):
    criterion = nn.CrossEntropyLoss()
    # Select the optimizer: 1 = Adam, 2 = Adagrad, anything else = SGD with momentum.
    if opt == 1:
        optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)
    elif opt == 2:
        optimizer = optim.Adagrad(net.parameters(), lr=lr, lr_decay=0, weight_decay=weight_decay, initial_accumulator_value=0)
    else:
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    print("Training started...")
    for epoch in range(epochs):  # loop over the dataset multiple times

        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data
            inputs = inputs.float().to(device)
            labels = labels.long().to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # accumulate statistics (loss summed over minibatches)
            running_loss += loss.item()

        print('Epoch [%d / %d], Loss: %.10f' % (epoch + 1, epochs, running_loss))

        # Report accuracies every 20 epochs.
        if epoch % 20 == 19:
            print('Accuracy of the network on the train images: %d %%' % accshow(trainloader, net))
            print('Accuracy of the network on the cross-validation images: %d %%' % accshow(cvloader, net))
            print('Accuracy of the network on the test images: %d %%' % accshow(testloader, net))

        scheduler.step()
    print('Training finished.')


def checking(trainloader, cvloader, testloader, epochs, step_size, loops):
    # Grid search over optimizer, learning rate and weight decay; keeps the nets
    # with the best cross-validation and best test accuracy seen so far.
    cv_score = cv_train = cv_test = cv_opt = cv_rate = cv_decay = 0
    test_score = test_train = test_cv = test_opt = test_rate = test_decay = 0
    cv_net = test_net = None
    opt_list = [1, 2, 3]
    rate_list = [0.001, 0.003, 0.005, 0.01, 0.03, 0.05]
    print("Checking started...")
    for l in rate_list:
        for w in rate_list:
            for o in opt_list:
                for i in range(loops):
                    net = Net()
                    net = net.to(device)
                    training(trainloader, cvloader, testloader, net, o, epochs, l, w, step_size, gamma=0.1)
                    train_acc = accshow(trainloader, net)
                    cv_acc = accshow(cvloader, net)
                    test_acc = accshow(testloader, net)
                    print("Train: %d %%. CV: %d %%. Test: %d %%. Opt: %d. Rate: %.3f. Decay: %.3f. Loop: %d" % (train_acc, cv_acc, test_acc, o, l, w, i+1))
                    if cv_acc > cv_score:
                        cv_score = cv_acc
                        cv_train = train_acc
                        cv_test = test_acc
                        cv_opt = o
                        cv_rate = l
                        cv_decay = w
                        cv_net = net
                    if test_acc > test_score:
                        test_score = test_acc
                        test_train = train_acc
                        test_cv = cv_acc
                        test_opt = o
                        test_rate = l
                        test_decay = w
                        test_net = net
    print("Checking finished.")
    print("Best cross-validation set accuracy:")
    print("Train: %d %%. CV: %d %%. Test: %d %%. Opt: %d. Rate: %.3f. Decay: %.3f." % (cv_train, cv_score, cv_test, cv_opt, cv_rate, cv_decay))
    print("Best test set accuracy:")
    print("Train: %d %%. CV: %d %%. Test: %d %%. Opt: %d. Rate: %.3f. Decay: %.3f." % (test_train, test_cv, test_score, test_opt, test_rate, test_decay))

    cv_net = cv_net.to(device)
    print("Checking best cross-validation set.")
    print('Accuracy of the network on the train images: %d %%' % accshow(trainloader, cv_net))
    print('Accuracy of the network on the cross-validation images: %d %%' % accshow(cvloader, cv_net))
    print('Accuracy of the network on the test images: %d %%' % accshow(testloader, cv_net))
    test_net = test_net.to(device)
    print("Checking best test set.")
    print('Accuracy of the network on the train images: %d %%' % accshow(trainloader, test_net))
    print('Accuracy of the network on the cross-validation images: %d %%' % accshow(cvloader, test_net))
    print('Accuracy of the network on the test images: %d %%' % accshow(testloader, test_net))


def accshow(loader, net):
    # Overall accuracy (in percent) with dropout disabled during evaluation.
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for data in loader:
            inputs, labels = data
            inputs = inputs.float().to(device)
            labels = labels.long().to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train()
    return 100 * correct / total


def classaccshow(loader, net, classes, classes_length):
    # Per-class accuracy with dropout disabled during evaluation.
    class_correct = list(0. for i in range(classes_length))
    class_total = list(0. for i in range(classes_length))
    net.eval()
    with torch.no_grad():
        for data in loader:
            inputs, labels = data
            inputs = inputs.float().to(device)
            labels = labels.long().to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels)
            # Iterate over the actual batch size instead of a hard-coded count.
            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    net.train()
    for i in range(classes_length):
        print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))


def imshow(img):
    img = img * 0.13 + 0.5     # roughly undo the normalization (approximate mean/std)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


def imgshow(loader, net, classes):
    # Show one small batch with its ground-truth and predicted labels.
    dataiter = iter(loader)
    inputs, labels = next(dataiter)
    imshow(torchvision.utils.make_grid(inputs))
    print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    inputs = inputs.to(device)
    net.eval()
    with torch.no_grad():
        outputs = net(inputs)
    _, predicted = torch.max(outputs, 1)
    print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))


if __name__ == '__main__':

    # device is read as a module-level global by the functions above.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    classes = ('alder', 'beech', 'birch', 'chestnut', 'gingko biloba', 'hornbeam', 'horse chestnut', 'linden', 'oak', 'oriental plane', 'pine', 'spruce')
    classes_length = len(classes)
    batch_size = 60
    trainloader, cvloader, testloader, testloader2 = dataloader(batch_size)

    net = Net()
    net = net.to(device)
    epochs = 500
    step_size = 400
    opt = 1                  # 1 = Adam, 2 = Adagrad, otherwise SGD
    learning_rate = 0.003
    weight_decay = 0.03

    training(trainloader, cvloader, testloader, net, opt, epochs, learning_rate, weight_decay, step_size, gamma=0.1)

#    checking(trainloader, cvloader, testloader, epochs, step_size, loops=1)

#    classaccshow(cvloader, net, classes, classes_length)
#    classaccshow(testloader, net, classes, classes_length)
    imgshow(testloader2, net, classes)