import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5, stride=2, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=2, padding=2)
        self.fc1 = nn.Linear(784, 240)   # 784 = 16 channels * 7 * 7 (see shape note below)
        self.fc2 = nn.Linear(240, 84)
        self.fc3 = nn.Linear(84, 12)     # 12 output classes
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, self.num_flat_features(x))  # flatten to (batch, features)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
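
# Shape sanity check for 128x128 RGB inputs (how fc1's in_features of 784 comes about):
#   input                   3 x 128 x 128
#   conv1 (k=5, s=2, p=2) ->  6 x 64 x 64
#   pool  (k=3, s=2)      ->  6 x 31 x 31
#   conv2 (k=5, s=2, p=2) -> 16 x 16 x 16
#   pool  (k=3, s=2)      -> 16 x  7 x  7  ->  16 * 7 * 7 = 784 flat features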

class Memory(Dataset):
    """Holds the whole (already transformed) dataset as NumPy arrays in RAM."""
    def __init__(self, dataset_array, dataset_labels):
        self.labels = dataset_labels.astype(np.int64)    # class indices
        self.images = dataset_array.astype(np.float32)   # float32 suffices and halves memory use

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]
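
# Note: the Memory dataset is filled once (below, in __main__) by pushing the entire
# ImageFolder through a DataLoader with batch_size == len(dataset), so the Resize /
# ToTensor / Normalize transforms run a single time; afterwards every epoch reads the
# decoded tensors straight from RAM instead of re-reading the image files from disk.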

def training(epochs, lr, weight_decay, step_size, gamma):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    steps_per_epoch = trainset_length // batch_size

    print("Training started...")
    for epoch in range(epochs):  # loop over the dataset multiple times
        net.train()  # re-enable dropout (accshow switches the net to eval mode)

        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data
            inputs = inputs.float().to(device)
            labels = labels.long().to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print the average loss once per epoch, on the last full batch
            running_loss += loss.item()
            if i % steps_per_epoch == steps_per_epoch - 1:
                print('Epoch [%d / %d], Step: %d, Loss: %.10f' %
                      (epoch + 1, epochs, i + 1, running_loss / steps_per_epoch))
                running_loss = 0.0

        # report accuracies every 5 epochs
        if epoch % 5 == 4:
            accshow(trainloader, 'train')
            accshow(cvloader, 'cross-validation')
            accshow(testloader, 'test')

        scheduler.step()
    print('Training finished.')
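
# training() and the helpers below rely on the module-level names created in the
# __main__ block (net, device, trainloader, cvloader, testloader, batch_size,
# trainset_length, classes, classes_length), so they must be called after that setup.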

def imshow(img):
    img = img / 2 + 0.5     # unnormalize (inverse of Normalize with mean=0.5, std=0.5)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

def imgshow(loader):
    # show one batch of images together with ground-truth and predicted labels
    net.eval()
    dataiter = iter(loader)
    inputs, labels = next(dataiter)  # use the builtin next(); DataLoader iterators no longer expose .next()
    imshow(torchvision.utils.make_grid(inputs))
    print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    inputs = inputs.to(device)
    outputs = net(inputs)
    _, predicted = torch.max(outputs, 1)
    print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))

def accshow(loader, which):
    # overall accuracy of the network on the images served by `loader`
    net.eval()  # disable dropout while evaluating
    correct = 0
    total = 0
    with torch.no_grad():
        for data in loader:
            inputs, labels = data
            inputs = inputs.float().to(device)
            labels = labels.long().to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the', which, 'images: %d %%' % (100 * correct / total))

def classaccshow(loader):
    # per-class accuracy of the network on the images served by `loader`
    net.eval()  # disable dropout while evaluating
    class_correct = list(0. for i in range(classes_length))
    class_total = list(0. for i in range(classes_length))
    with torch.no_grad():
        for data in loader:
            inputs, labels = data
            inputs = inputs.float().to(device)
            labels = labels.long().to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels)
            # iterate over the actual batch size instead of a guessed fixed range
            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    for i in range(classes_length):
        if class_total[i] > 0:  # skip classes that never appeared in this loader
            print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

if __name__ == '__main__':

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # five variants of the preprocessing pipeline; all end with the same normalization
    transform1 = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    transform2 = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomVerticalFlip(p=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    transform3 = transforms.Compose([
        transforms.Resize((128, 128)),
        # note: ColorJitter with all parameters at 0 leaves the image unchanged
        transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    transform4 = transforms.Compose([
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    transform5 = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomVerticalFlip(p=1),
        transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset1 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform1, target_transform=None)
    dataset2 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform2, target_transform=None)
    dataset3 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform3, target_transform=None)
    dataset4 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform4, target_transform=None)
    dataset5 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform5, target_transform=None)
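
    # ImageFolder expects one sub-folder per class inside 'TRUNK12' (e.g. TRUNK12/oak/*.jpg);
    # class indices are assigned from the alphabetically sorted folder names, which is what
    # makes the `classes` tuple below line up with the labels.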

    batch_size = 20

    print("Data loading...")
    dataset = dataset1  # + dataset2 + dataset3 + dataset4 + dataset5
    dataset_length = len(dataset)
    # load the whole dataset in one batch and cache it in a Memory dataset;
    # fetching the batch once avoids running the transforms over every image twice
    datasetloader = torch.utils.data.DataLoader(dataset, batch_size=dataset_length, shuffle=False)
    dataset_array, dataset_labels = next(iter(datasetloader))
    dataset_m = Memory(dataset_array.numpy(), dataset_labels.numpy())

    # 70 / 30 split into training and cross-validation sets
    trainset_length = int(0.7 * dataset_length)
    cvset_length = dataset_length - trainset_length
    trainset, cvset = torch.utils.data.random_split(dataset_m, (trainset_length, cvset_length))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    cvloader = torch.utils.data.DataLoader(cvset, batch_size=batch_size, shuffle=False)

    # separate test set, cached the same way
    testset = torchvision.datasets.ImageFolder(root='test', transform=transform1, target_transform=None)
    testset_length = len(testset)
    testsetloader = torch.utils.data.DataLoader(testset, batch_size=testset_length, shuffle=False)
    testset_array, testset_labels = next(iter(testsetloader))
    testset_m = Memory(testset_array.numpy(), testset_labels.numpy())
    testloader = torch.utils.data.DataLoader(testset_m, batch_size=testset_length, shuffle=False)
    testloader2 = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True)  # small batches for imgshow()
    print("Data loaded.")

    classes = ('alder', 'beech', 'birch', 'chestnut', 'gingko biloba', 'hornbeam', 'horse chestnut',
               'linden', 'oak', 'oriental plane', 'pine', 'spruce')
    classes_length = len(classes)

    net = Net()
    net = net.to(device)

    training(epochs=400, lr=0.001, weight_decay=0, step_size=200, gamma=0.1)

    # final evaluation: per-class accuracy on the cross-validation set,
    # then overall and per-class accuracy on the held-out test set
    classaccshow(cvloader)

    accshow(testloader, 'test')
    classaccshow(testloader)
    imgshow(testloader2)
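
    # (Optional sketch, not part of the original paste) To reuse the trained weights later,
    # they could be saved with torch.save; the file name here is only a placeholder:
    # torch.save(net.state_dict(), 'trunk12_net.pth')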