SHARE
TWEET

Untitled

a guest May 19th, 2019 56 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. import torch.optim as optim
  9.  
  10.  
  11. class Net(nn.Module):
  12.     def __init__(self):
  13.         super(Net, self).__init__()
  14.         self.conv1 = nn.Conv2d(3, 8, 8)
  15.         self.pool = nn.MaxPool2d(2, 2)
  16.         self.conv2 = nn.Conv2d(8, 16, 8)
  17.         self.fc1 = nn.Linear(53824, 120)
  18.         self.fc2 = nn.Linear(120, 84)
  19.         self.fc3 = nn.Linear(84, 12)
  20.  
  21.     def forward(self, x):
  22.         x = self.pool(F.relu(self.conv1(x)))
  23.         x = self.pool(F.relu(self.conv2(x)))
  24.         x = x.view(-1, self.num_flat_features(x))
  25.         x = F.relu(self.fc1(x))
  26.         x = F.relu(self.fc2(x))
  27.         x = self.fc3(x)
  28.         return x
  29.  
  30.     def num_flat_features(self, x):
  31.         size = x.size()[1:]  # all dimensions except the batch dimension
  32.         num_features = 1
  33.         for s in size:
  34.             num_features *= s
  35.         return num_features
  36.  
  37. def imshow(img):
  38.     img = img / 2 + 0.5     # unnormalize
  39.     npimg = img.numpy()
  40.     plt.imshow(np.transpose(npimg, (1, 2, 0)))
  41.     plt.show()
  42.  
  43. def imgshow(loader):
  44.     dataiter = iter(loader)
  45.     inputs, labels = dataiter.next()  
  46.     imshow(torchvision.utils.make_grid(inputs))
  47.     print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))  
  48.     inputs = inputs.to(device)
  49.     outputs = net(inputs)  
  50.     _, predicted = torch.max(outputs, 1)  
  51.     print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
  52.  
  53. def accshow(loader, which):
  54.     correct = 0
  55.     total = 0
  56.     with torch.no_grad():
  57.         for data in loader:
  58.             inputs, labels = data
  59.             inputs = inputs.to(device)
  60.             labels = labels.to(device)
  61.             outputs = net(inputs)
  62.             _, predicted = torch.max(outputs.data, 1)
  63.             total += labels.size(0)
  64.             correct += (predicted == labels).sum().item()  
  65.     print('Accuracy of the network on the', which, 'images: %d %%' % (100 * correct / total))
  66.  
  67. def classaccshow(loader):  
  68.     class_correct = list(0. for i in range(classes_length))
  69.     class_total = list(0. for i in range(classes_length))
  70.     with torch.no_grad():
  71.         for data in loader:
  72.             inputs, labels = data
  73.             inputs = inputs.to(device)
  74.             labels = labels.to(device)
  75.             outputs = net(inputs)
  76.             _, predicted = torch.max(outputs, 1)
  77.             c = (predicted == labels).squeeze()
  78.             try:
  79.                 for i in range(classes_length*2):
  80.                     label = labels[i]
  81.                     class_correct[label] += c[i].item()
  82.                     class_total[label] += 1
  83.             except IndexError:
  84.                 pass  
  85.     for i in range(classes_length):
  86.         print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
  87.  
  88. if __name__ == '__main__':
  89.  
  90.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  91.  
  92.     transform1 = transforms.Compose(
  93.     [transforms.Resize((256, 256)),
  94.      transforms.ToTensor(),
  95.      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  96.      ])
  97.  
  98.     transform2 = transforms.Compose(
  99.     [transforms.Resize((256, 256)),
  100.      transforms.RandomVerticalFlip(p=1),
  101.      transforms.ToTensor(),
  102.      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  103.      ])
  104.  
  105.     transform3 = transforms.Compose(
  106.     [transforms.Resize((256, 256)),
  107.      transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0),
  108.      transforms.ToTensor(),
  109.      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  110.      ])
  111.  
  112.     transform4 = transforms.Compose(
  113.     [transforms.Resize((256, 256)),
  114.      transforms.CenterCrop(256),
  115.      transforms.ToTensor(),
  116.      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  117.      ])
  118.  
  119.     transform5 = transforms.Compose(
  120.     [transforms.Resize((256, 256)),
  121.      transforms.RandomVerticalFlip(p=1),
  122.      transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0),
  123.      transforms.ToTensor(),
  124.      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  125.      ])
  126.  
  127.     dataset1 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform1, target_transform=None)
  128.     dataset2 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform2, target_transform=None)
  129.     dataset3 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform3, target_transform=None)
  130.     dataset4 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform4, target_transform=None)
  131.     dataset5 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform5, target_transform=None)
  132.     dataset = dataset1 + dataset2 + dataset3 + dataset4 + dataset5
  133.     dataset_length = dataset.__len__()
  134.     trainset_length = int(0.7*dataset_length)
  135.     testset_length = dataset_length - trainset_length
  136.     trainset, testset = torch.utils.data.random_split(dataset, (trainset_length, testset_length))
  137.     truetestset = torchvision.datasets.ImageFolder(root='test', transform=transform1, target_transform=None)
  138.     batch_size = 10
  139.     trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
  140.     testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
  141.     truetestloader = torch.utils.data.DataLoader(truetestset, batch_size=4, shuffle=True, num_workers=2)
  142.  
  143.     classes = ('alder', 'beech', 'birch', 'chestnut', 'gingko biloba', 'hornbeam', 'horse chestnut', 'linden', 'oak', 'oriental plane', 'pine', 'spruce')
  144.     classes_length = len(classes)
  145.  
  146.     net = Net()
  147.     net = net.to(device)
  148.  
  149.     criterion = nn.CrossEntropyLoss()
  150.     optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0)
  151.     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
  152.  
  153.     for epoch in range(100):  # loop over the dataset multiple times
  154.  
  155.         running_loss = 0.0
  156.         for i, data in enumerate(trainloader, 0):
  157.             # get the inputs
  158.             inputs, labels = data
  159.             inputs = inputs.to(device)
  160.             labels = labels.to(device)
  161.  
  162.             # zero the parameter gradients
  163.             optimizer.zero_grad()
  164.  
  165.             # forward + backward + optimize
  166.             outputs = net(inputs)
  167.             loss = criterion(outputs, labels)
  168.             loss.backward()
  169.             optimizer.step()
  170.  
  171.             # print statistics
  172.             running_loss += loss.item()
  173.             if i % 100 == 99:
  174.                 print('[%d, %5d] loss: %.10f' %
  175.                       (epoch + 1, i + 1, running_loss / 100))
  176.                 running_loss = 0.0
  177.  
  178.         if epoch % 5 == 4:
  179.             accshow(trainloader, 'train')
  180.             accshow(testloader, 'test')
  181.             accshow(truetestloader, 'truetest')
  182.  
  183.         scheduler.step()
  184.     print('Finished Training')
  185.  
  186.     classaccshow(testloader)
  187.  
  188.     accshow(truetestloader, 'truetest')
  189.     classaccshow(truetestloader)
  190.     imgshow(truetestloader)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top