Advertisement
Guest User

Untitled

a guest
May 19th, 2019
78
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.03 KB | None | 0 0
  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. import torch.optim as optim
  9.  
  10.  
  11. class Net(nn.Module):
  12.     def __init__(self):
  13.         super(Net, self).__init__()
  14.         self.conv1 = nn.Conv2d(3, 8, 8)
  15.         self.pool = nn.MaxPool2d(2, 2)
  16.         self.conv2 = nn.Conv2d(8, 16, 8)
  17.         self.fc1 = nn.Linear(53824, 120)
  18.         self.fc2 = nn.Linear(120, 84)
  19.         self.fc3 = nn.Linear(84, 12)
  20.  
  21.     def forward(self, x):
  22.         x = self.pool(F.relu(self.conv1(x)))
  23.         x = self.pool(F.relu(self.conv2(x)))
  24.         x = x.view(-1, self.num_flat_features(x))
  25.         x = F.relu(self.fc1(x))
  26.         x = F.relu(self.fc2(x))
  27.         x = self.fc3(x)
  28.         return x
  29.  
  30.     def num_flat_features(self, x):
  31.         size = x.size()[1:]  # all dimensions except the batch dimension
  32.         num_features = 1
  33.         for s in size:
  34.             num_features *= s
  35.         return num_features
  36.  
  37. def imshow(img):
  38.     img = img / 2 + 0.5     # unnormalize
  39.     npimg = img.numpy()
  40.     plt.imshow(np.transpose(npimg, (1, 2, 0)))
  41.     plt.show()
  42.  
  43. def imgshow(loader):
  44.     dataiter = iter(loader)
  45.     inputs, labels = dataiter.next()  
  46.     imshow(torchvision.utils.make_grid(inputs))
  47.     print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))  
  48.     inputs = inputs.to(device)
  49.     outputs = net(inputs)  
  50.     _, predicted = torch.max(outputs, 1)  
  51.     print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
  52.  
  53. def accshow(loader, which):
  54.     correct = 0
  55.     total = 0
  56.     with torch.no_grad():
  57.         for data in loader:
  58.             inputs, labels = data
  59.             inputs = inputs.to(device)
  60.             labels = labels.to(device)
  61.             outputs = net(inputs)
  62.             _, predicted = torch.max(outputs.data, 1)
  63.             total += labels.size(0)
  64.             correct += (predicted == labels).sum().item()  
  65.     print('Accuracy of the network on the', which, 'images: %d %%' % (100 * correct / total))
  66.  
  67. def classaccshow(loader):  
  68.     class_correct = list(0. for i in range(classes_length))
  69.     class_total = list(0. for i in range(classes_length))
  70.     with torch.no_grad():
  71.         for data in loader:
  72.             inputs, labels = data
  73.             inputs = inputs.to(device)
  74.             labels = labels.to(device)
  75.             outputs = net(inputs)
  76.             _, predicted = torch.max(outputs, 1)
  77.             c = (predicted == labels).squeeze()
  78.             try:
  79.                 for i in range(classes_length*2):
  80.                     label = labels[i]
  81.                     class_correct[label] += c[i].item()
  82.                     class_total[label] += 1
  83.             except IndexError:
  84.                 pass  
  85.     for i in range(classes_length):
  86.         print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
  87.  
  88. if __name__ == '__main__':
  89.  
  90.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  91.  
  92.     transform1 = transforms.Compose(
  93.     [transforms.Resize((256, 256)),
  94.      transforms.ToTensor(),
  95.      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  96.      ])
  97.  
  98.     transform2 = transforms.Compose(
  99.     [transforms.Resize((256, 256)),
  100.      transforms.RandomVerticalFlip(p=1),
  101.      transforms.ToTensor(),
  102.      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  103.      ])
  104.  
  105.     transform3 = transforms.Compose(
  106.     [transforms.Resize((256, 256)),
  107.      transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0),
  108.      transforms.ToTensor(),
  109.      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  110.      ])
  111.  
  112.     transform4 = transforms.Compose(
  113.     [transforms.Resize((256, 256)),
  114.      transforms.CenterCrop(256),
  115.      transforms.ToTensor(),
  116.      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  117.      ])
  118.  
  119.     transform5 = transforms.Compose(
  120.     [transforms.Resize((256, 256)),
  121.      transforms.RandomVerticalFlip(p=1),
  122.      transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0),
  123.      transforms.ToTensor(),
  124.      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  125.      ])
  126.  
  127.     dataset1 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform1, target_transform=None)
  128.     dataset2 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform2, target_transform=None)
  129.     dataset3 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform3, target_transform=None)
  130.     dataset4 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform4, target_transform=None)
  131.     dataset5 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform5, target_transform=None)
  132.     dataset = dataset1 + dataset2 + dataset3 + dataset4 + dataset5
  133.     dataset_length = dataset.__len__()
  134.     trainset_length = int(0.7*dataset_length)
  135.     testset_length = dataset_length - trainset_length
  136.     trainset, testset = torch.utils.data.random_split(dataset, (trainset_length, testset_length))
  137.     truetestset = torchvision.datasets.ImageFolder(root='test', transform=transform1, target_transform=None)
  138.     batch_size = 10
  139.     trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
  140.     testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
  141.     truetestloader = torch.utils.data.DataLoader(truetestset, batch_size=4, shuffle=True, num_workers=2)
  142.  
  143.     classes = ('alder', 'beech', 'birch', 'chestnut', 'gingko biloba', 'hornbeam', 'horse chestnut', 'linden', 'oak', 'oriental plane', 'pine', 'spruce')
  144.     classes_length = len(classes)
  145.  
  146.     net = Net()
  147.     net = net.to(device)
  148.  
  149.     criterion = nn.CrossEntropyLoss()
  150.     optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0)
  151.     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
  152.  
  153.     for epoch in range(100):  # loop over the dataset multiple times
  154.  
  155.         running_loss = 0.0
  156.         for i, data in enumerate(trainloader, 0):
  157.             # get the inputs
  158.             inputs, labels = data
  159.             inputs = inputs.to(device)
  160.             labels = labels.to(device)
  161.  
  162.             # zero the parameter gradients
  163.             optimizer.zero_grad()
  164.  
  165.             # forward + backward + optimize
  166.             outputs = net(inputs)
  167.             loss = criterion(outputs, labels)
  168.             loss.backward()
  169.             optimizer.step()
  170.  
  171.             # print statistics
  172.             running_loss += loss.item()
  173.             if i % 100 == 99:
  174.                 print('[%d, %5d] loss: %.10f' %
  175.                       (epoch + 1, i + 1, running_loss / 100))
  176.                 running_loss = 0.0
  177.  
  178.         if epoch % 5 == 4:
  179.             accshow(trainloader, 'train')
  180.             accshow(testloader, 'test')
  181.             accshow(truetestloader, 'truetest')
  182.  
  183.         scheduler.step()
  184.     print('Finished Training')
  185.  
  186.     classaccshow(testloader)
  187.  
  188.     accshow(truetestloader, 'truetest')
  189.     classaccshow(truetestloader)
  190.     imgshow(truetestloader)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement