SHARE
TWEET

Untitled

a guest Jun 12th, 2019 90 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. import torch.optim as optim
  9. from torch.utils.data import Dataset
  10.  
  11.  
  12. class Net(nn.Module):
  13.     def __init__(self):
  14.         super(Net, self).__init__()
  15.         self.conv1 = nn.Conv2d(3, 6, 5, 2, 2)
  16.         self.pool = nn.MaxPool2d(3, 2)
  17.         self.conv2 = nn.Conv2d(6, 16, 5, 2, 2)
  18.         self.fc1 = nn.Linear(784, 240)
  19.         self.fc2 = nn.Linear(240, 84)
  20.         self.fc3 = nn.Linear(84, 12)
  21.         self.dropout = nn.Dropout(p=0.5)
  22.  
  23.     def forward(self, x):
  24.         x = self.pool(F.relu(self.conv1(x)))
  25.         x = self.pool(F.relu(self.conv2(x)))
  26.         x = x.view(-1, self.num_flat_features(x))
  27.         x = self.dropout(x)
  28.         x = F.relu(self.fc1(x))
  29.         x = F.relu(self.fc2(x))
  30.         x = self.fc3(x)
  31.         return x
  32.  
  33.     def num_flat_features(self, x):
  34.         size = x.size()[1:]  # all dimensions except the batch dimension
  35.         num_features = 1
  36.         for s in size:
  37.             num_features *= s
  38.         return num_features
  39.  
  40.  
  41. class Memory(Dataset):
  42.     def __init__(self, dataset_array, dataset_labels):
  43.         self.labels = dataset_labels.astype(np.float64)
  44.         self.images = dataset_array.astype(np.float64)
  45.     def __len__(self):
  46.         return self.images.shape[0]
  47.     def __getitem__(self, idx):
  48.         return self.images[idx], self.labels[idx]
  49.  
  50.    
  51. def dataloader(batch_size, num_workers=0):
  52. #    transform0 = transforms.Compose(
  53. #    [transforms.Resize((128, 128)),
  54. #     transforms.ToTensor()
  55. #     ])
  56.    
  57. #    datasetmeanstd = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform0)
  58. #    mean, std = calculate_img_stats_full(datasetmeanstd)
  59. #    mean = mean.numpy()
  60. #    std = std.numpy()
  61.     mean, std = [(0.56337297, 0.5472399, 0.5224609), (0.13041796, 0.13301197, 0.139214)]
  62.    
  63.     train_transform = transforms.Compose(
  64.     [transforms.Resize((128, 128)),
  65.      transforms.ToTensor(),
  66.      transforms.Normalize(mean, std)
  67.      ])
  68.    
  69.     test_transform = transforms.Compose(
  70.     [transforms.Resize((128, 128)),
  71.      transforms.ToTensor(),
  72.      transforms.Normalize(mean, std)
  73.      ])
  74.    
  75.     print("Data loading...")
  76.     trainset = torchvision.datasets.ImageFolder(root='train', transform=train_transform)
  77.     trainset_length = trainset.__len__()
  78.     trainloader = torch.utils.data.DataLoader(trainset, batch_size=trainset_length, shuffle=False, num_workers=num_workers, pin_memory=False)
  79.     cvset = torchvision.datasets.ImageFolder(root='cv', transform=test_transform)
  80.     cvset_length = cvset.__len__()
  81.     cvloader = torch.utils.data.DataLoader(cvset, batch_size=cvset_length, shuffle=False, num_workers=num_workers, pin_memory=False)
  82.     testset = torchvision.datasets.ImageFolder(root='test', transform=test_transform)
  83.     testset_length = testset.__len__()
  84.     testloader = torch.utils.data.DataLoader(testset, batch_size=testset_length, shuffle=False, num_workers=num_workers, pin_memory=False)
  85.     testloader2 = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=num_workers, pin_memory=False)
  86.  
  87.     trainset_array = next(iter(trainloader))[0].numpy()
  88.     trainset_labels = next(iter(trainloader))[1].numpy()
  89.     trainset_m = Memory(trainset_array, trainset_labels)
  90.     trainloader_m = torch.utils.data.DataLoader(trainset_m, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=False)
  91.    
  92.     cvset_array = next(iter(cvloader))[0].numpy()
  93.     cvset_labels = next(iter(cvloader))[1].numpy()
  94.     cvset_m = Memory(cvset_array, cvset_labels)
  95.     cvloader_m = torch.utils.data.DataLoader(cvset_m, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=False)
  96.    
  97.     testset_array = next(iter(testloader))[0].numpy()
  98.     testset_labels = next(iter(testloader))[1].numpy()
  99.     testset_m = Memory(testset_array, testset_labels)
  100.     testloader_m = torch.utils.data.DataLoader(testset_m, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=False)
  101.     print("Data loaded.")
  102.     return trainloader_m, cvloader_m, testloader_m, testloader2
  103.  
  104.  
  105. def calculate_img_stats_full(dataset):
  106.     imgs_ = torch.stack([img for img,_ in dataset],dim=3)
  107.     imgs_ = imgs_.view(3,-1)
  108.     imgs_mean = imgs_.mean(dim=1)
  109.     imgs_std = imgs_.std(dim=1)
  110.     return imgs_mean, imgs_std
  111.  
  112.  
  113. def training(trainloader, cvloader, testloader, net, opt, epochs, lr, weight_decay, step_size, gamma):
  114.     criterion = nn.CrossEntropyLoss()
  115.     adam = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)
  116.     adagrad = optim.Adagrad(net.parameters(), lr=lr, lr_decay=0, weight_decay=weight_decay, initial_accumulator_value=0)
  117.     sgd = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
  118.     if opt == 1:
  119.         optimizer = adam
  120.     elif opt == 2:
  121.         optimizer = adagrad
  122.     else:
  123.         optimizer = sgd
  124.     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
  125.  
  126.     print("Training started...")
  127.     for epoch in range(epochs):  # loop over the dataset multiple times
  128.  
  129.         running_loss = 0.0
  130.         for i, data in enumerate(trainloader, 0):
  131.             # get the inputs
  132.             inputs, labels = data
  133.             inputs = inputs.float().to(device)
  134.             labels = labels.long().to(device)
  135.  
  136.             # zero the parameter gradients
  137.             optimizer.zero_grad()
  138.  
  139.             # forward + backward + optimize
  140.             outputs = net(inputs)
  141.             loss = criterion(outputs, labels)
  142.             loss.backward()
  143.             optimizer.step()
  144.  
  145.             # print statistics
  146.             running_loss += loss.item()
  147.  
  148.         print('Epoch [%d / %d], Loss: %.10f' % (epoch + 1, epochs, running_loss))
  149.  
  150.         if epoch % 20 == 19:
  151.             print('Accuracy of the network on the train images: %d %%' % accshow(trainloader, net))
  152.             print('Accuracy of the network on the cross-validation images: %d %%' % accshow(cvloader, net))
  153.             print('Accuracy of the network on the test images: %d %%' % accshow(testloader, net))
  154.  
  155.         scheduler.step()
  156.     print('Training finished.')
  157.    
  158.    
  159. def checking(trainloader, cvloader, testloader, epochs, step_size, loops):
  160.     cv_score, cv_train, cv_test, cv_opt, cv_rate, cv_decay, test_score, test_train, test_cv, test_opt, test_rate, test_decay = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  161.     opt_list = [1, 2, 3]
  162.     rate_list = [0.001, 0.003, 0.005, 0.01, 0.03, 0.05]
  163.     print("Checking started...")
  164.     for l in rate_list:
  165.         for w in rate_list:
  166.             for o in opt_list:
  167.                 for i in range(loops):
  168.                     net = Net()
  169.                     net = net.to(device)    
  170.                     training(trainloader, cvloader, testloader, net, o, epochs, l, w, step_size, gamma=0.1)
  171.                     train_acc = accshow(trainloader, net)
  172.                     cv_acc = accshow(cvloader, net)
  173.                     test_acc = accshow(testloader, net)
  174.                     print("Train: %d %%. CV: %d %%. Test: %d %%. Opt: %d. Rate: %.3f. Decay: %.3f. Loop: %d" % (train_acc, cv_acc, test_acc, o, l, w, i+1))
  175.                     if cv_acc > cv_score:
  176.                         cv_score = cv_acc
  177.                         cv_train = train_acc
  178.                         cv_test = test_acc
  179.                         cv_opt = o
  180.                         cv_rate = l
  181.                         cv_decay = w
  182.                         cv_net = net
  183.                     if test_acc > test_score:
  184.                         test_score = test_acc
  185.                         test_train = train_acc
  186.                         test_cv = cv_acc
  187.                         test_opt = o
  188.                         test_rate = l
  189.                         test_decay = w
  190.                         test_net = net
  191.     print("Checking finished.")
  192.     print("Best cross-validation set accuracy:")
  193.     print("Train: %d %%. CV: %d %%. Test: %d %%. Opt: %d. Rate: %.3f. Decay: %.3f." % (cv_train, cv_score, cv_test, cv_opt, cv_rate, cv_decay))
  194.     print("Best test set accuracy:")
  195.     print("Train: %d %%. CV: %d %%. Test: %d %%. Opt: %d. Rate: %.3f. Decay: %.3f." % (test_train, test_cv, test_score, test_opt, test_rate, test_decay))
  196.    
  197.     cv_net = cv_net.to(device)
  198.     print("Checking best cross-validation set.")
  199.     print('Accuracy of the network on the train images: %d %%' % accshow(trainloader, cv_net))
  200.     print('Accuracy of the network on the cross-validation images: %d %%' % accshow(cvloader, cv_net))
  201.     print('Accuracy of the network on the test images: %d %%' % accshow(testloader, cv_net))
  202.     test_net = test_net.to(device)
  203.     print("Checking best test set.")
  204.     print('Accuracy of the network on the train images: %d %%' % accshow(trainloader, test_net))
  205.     print('Accuracy of the network on the cross-validation images: %d %%' % accshow(cvloader, test_net))
  206.     print('Accuracy of the network on the test images: %d %%' % accshow(testloader, test_net))
  207.  
  208.    
  209. def accshow(loader, net):
  210.     correct = 0
  211.     total = 0
  212.     with torch.no_grad():
  213.         for data in loader:
  214.             inputs, labels = data
  215.             inputs = inputs.float().to(device)
  216.             labels = labels.long().to(device)
  217.             outputs = net(inputs)
  218.             _, predicted = torch.max(outputs.data, 1)
  219.             total += labels.size(0)
  220.             correct += (predicted == labels).sum().item()  
  221.     return (100 * correct / total)
  222.  
  223.  
  224. def classaccshow(loader, net, classes, classes_length):  
  225.     class_correct = list(0. for i in range(classes_length))
  226.     class_total = list(0. for i in range(classes_length))
  227.     with torch.no_grad():
  228.         for data in loader:
  229.             inputs, labels = data
  230.             inputs = inputs.float().to(device)
  231.             labels = labels.long().to(device)
  232.             outputs = net(inputs)
  233.             _, predicted = torch.max(outputs, 1)
  234.             c = (predicted == labels).squeeze()
  235.             try:
  236.                 for i in range(12):
  237.                     label = labels[i]
  238.                     class_correct[label] += c[i].item()
  239.                     class_total[label] += 1
  240.             except IndexError:
  241.                 pass  
  242.     for i in range(classes_length):
  243.         print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
  244.  
  245.  
  246. def imshow(img):
  247.     img = img * 0.13 + 0.5     # unnormalize
  248.     npimg = img.numpy()
  249.     plt.imshow(np.transpose(npimg, (1, 2, 0)))
  250.     plt.show()
  251.  
  252.  
  253. def imgshow(loader, net, classes):
  254.     dataiter = iter(loader)
  255.     inputs, labels = dataiter.next()
  256.     imshow(torchvision.utils.make_grid(inputs))
  257.     print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))  
  258.     inputs = inputs.to(device)
  259.     outputs = net(inputs)  
  260.     _, predicted = torch.max(outputs, 1)  
  261.     print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
  262.  
  263.  
  264.  
  265. if __name__ == '__main__':
  266.  
  267.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  268.    
  269.     classes = ('alder', 'beech', 'birch', 'chestnut', 'gingko biloba', 'hornbeam', 'horse chestnut', 'linden', 'oak', 'oriental plane', 'pine', 'spruce')
  270.     classes_length = len(classes)    
  271.     batch_size = 60
  272.     trainloader, cvloader, testloader, testloader2 = dataloader(batch_size)
  273.  
  274.     net = Net()
  275.     net = net.to(device)    
  276.     epochs = 500
  277.     step_size = 400
  278.     opt = 1
  279.     learning_rate = 0.003
  280.     weight_decay = 0.03
  281.    
  282.     training(trainloader, cvloader, testloader, net, opt, epochs, learning_rate, weight_decay, step_size, gamma=0.1)
  283.    
  284. #    checking(trainloader, cvloader, testloader, epochs, step_size, loops=1)
  285.        
  286. #    classaccshow(cvloader, net, classes, classes_length)
  287. #    classaccshow(testloader, net, classes, classes_length)
  288.     imgshow(testloader2, net, classes)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top