import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5, 2, 2)    # in=3, out=6, kernel=5, stride=2, padding=2
        self.pool = nn.MaxPool2d(3, 2)           # 3x3 max pool with stride 2, shared by both conv layers
        self.conv2 = nn.Conv2d(6, 16, 5, 2, 2)   # in=6, out=16, kernel=5, stride=2, padding=2
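        # For the 128x128 inputs used below, the feature maps shrink as
        # 128 -> conv1 -> 64 -> pool -> 31 -> conv2 -> 16 -> pool -> 7,
        # so the flattened input to fc1 is 16 * 7 * 7 = 784.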
        self.fc1 = nn.Linear(784, 240)
        self.fc2 = nn.Linear(240, 84)
        self.fc3 = nn.Linear(84, 12)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, self.num_flat_features(x))
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

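# Memory caches the already-transformed images and labels as NumPy arrays,
# so the torchvision transforms are applied only once up front rather than
# on every epoch.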
class Memory(Dataset):
    def __init__(self, dataset_array, dataset_labels):
        self.labels = dataset_labels.astype(np.float64)
        self.images = dataset_array.astype(np.float64)
    def __len__(self):
        return self.images.shape[0]
    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

def training(epochs, lr, weight_decay, step_size, gamma):
    criterion = nn.CrossEntropyLoss()
#    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)
#    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    optimizer = optim.Adagrad(net.parameters(), lr=lr, lr_decay=0, weight_decay=weight_decay, initial_accumulator_value=0)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
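    # StepLR multiplies the learning rate by gamma after every step_size
    # epochs; with the values used at the call site below (step_size=500,
    # gamma=0.1) the rate drops once, by a factor of 10, halfway through
    # the 1000 epochs.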

    print("Training started...")
    for epoch in range(epochs):  # loop over the dataset multiple times

        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data
            inputs = inputs.float().to(device)
            labels = labels.long().to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
#            if i % int((trainset_length // batch_size) / 1) == int(((trainset_length // batch_size) / 1) - 1):
#                print('Epoch [%d / %d], Step: %d, Loss: %.10f' %
#                      (epoch + 1, epochs, i + 1, running_loss / int((trainset_length // batch_size) / 1)))
#                running_loss = 0.0

        # average loss per batch; len(trainloader) also counts a partial final batch
        print('Epoch [%d / %d], Loss: %.10f' %
              (epoch + 1, epochs, running_loss / len(trainloader)))

        if epoch % 20 == 19:
            print('Accuracy of the network on the train images: %d %%' % accshow(trainloader))
            print('Accuracy of the network on the cross-validation images: %d %%' % accshow(cvloader))
            print('Accuracy of the network on the test images: %d %%' % accshow(testloader))

        scheduler.step()
    print('Training finished.')

def accshow(loader):
    net.eval()   # disable dropout while measuring accuracy
    correct = 0
    total = 0
    with torch.no_grad():
        for data in loader:
            inputs, labels = data
            inputs = inputs.float().to(device)
            labels = labels.long().to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train()  # restore training mode for the caller
    return 100 * correct / total

def classaccshow(loader):
    net.eval()   # disable dropout while measuring accuracy
    class_correct = list(0. for i in range(classes_length))
    class_total = list(0. for i in range(classes_length))
    with torch.no_grad():
        for data in loader:
            inputs, labels = data
            inputs = inputs.float().to(device)
            labels = labels.long().to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels)
            for i in range(labels.size(0)):   # count every sample in the batch
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    net.train()  # restore training mode for the caller
    for i in range(classes_length):
        print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def imgshow(loader):
    net.eval()   # disable dropout for the predictions printed below
    dataiter = iter(loader)
    inputs, labels = next(dataiter)
    imshow(torchvision.utils.make_grid(inputs))
    print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = net(inputs)
    _, predicted = torch.max(outputs, 1)
    print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))


if __name__ == '__main__':

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    transform1 = transforms.Compose(
        [transforms.Resize((128, 128)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ])

    transform2 = transforms.Compose(
        [transforms.Resize((128, 128)),
         transforms.RandomVerticalFlip(p=1),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ])

    transform3 = transforms.Compose(
        [transforms.Resize((128, 128)),
         transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ])

    transform4 = transforms.Compose(
        [transforms.CenterCrop(128),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ])

    transform5 = transforms.Compose(
        [transforms.Resize((128, 128)),
         transforms.RandomVerticalFlip(p=1),
         transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ])

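    # transform1-transform5 define five augmented views of the same folder.
    # Note that ColorJitter with all parameters at 0 is a no-op (so transform3
    # and transform5 duplicate transform1 and transform2 respectively) and
    # RandomVerticalFlip(p=1) flips every image deterministically. Only
    # dataset1 is used below, since the concatenation is commented out.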
    dataset1 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform1, target_transform=None)
    dataset2 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform2, target_transform=None)
    dataset3 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform3, target_transform=None)
    dataset4 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform4, target_transform=None)
    dataset5 = torchvision.datasets.ImageFolder(root='TRUNK12', transform=transform5, target_transform=None)

    batch_size = 20

    print("Data loading...")
    dataset = dataset1 #+ dataset2 + dataset3 + dataset4 + dataset5
    dataset_length = len(dataset)
    # pull the whole dataset through the transforms in a single batch and cache it
    datasetloader = torch.utils.data.DataLoader(dataset, batch_size=dataset_length, shuffle=False)
    dataset_array, dataset_labels = next(iter(datasetloader))
    dataset_m = Memory(dataset_array.numpy(), dataset_labels.numpy())

    # 70/30 split into training and cross-validation sets
    trainset_length = int(0.7 * dataset_length)
    cvset_length = dataset_length - trainset_length
    trainset, cvset = torch.utils.data.random_split(dataset_m, (trainset_length, cvset_length))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    cvloader = torch.utils.data.DataLoader(cvset, batch_size=batch_size, shuffle=False)

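    # The test set is prepared the same way: one full-size batch is pulled
    # through transform1 and cached in a Memory dataset (testloader, used for
    # accuracy checks), while testloader2 serves shuffled batches of four
    # straight from the ImageFolder for visualization with imgshow().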
    testset = torchvision.datasets.ImageFolder(root='test', transform=transform1, target_transform=None)
    testset_length = len(testset)
    testsetloader = torch.utils.data.DataLoader(testset, batch_size=testset_length, shuffle=False)
    testset_array, testset_labels = next(iter(testsetloader))
    testset_m = Memory(testset_array.numpy(), testset_labels.numpy())
    testloader = torch.utils.data.DataLoader(testset_m, batch_size=testset_length, shuffle=False)
    testloader2 = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True)
    print("Data loaded.")

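    # ImageFolder assigns integer labels from the alphabetically sorted
    # subfolder names, so this tuple presumably lists the 12 species in that
    # same order (matching the 12 outputs of fc3).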
    classes = ('alder', 'beech', 'birch', 'chestnut', 'gingko biloba', 'hornbeam', 'horse chestnut', 'linden', 'oak', 'oriental plane', 'pine', 'spruce')
    classes_length = len(classes)

    net = Net()
    net = net.to(device)

    training(epochs=1000, lr=0.003, weight_decay=0.003, step_size=500, gamma=0.1)

#    classaccshow(cvloader)
#    classaccshow(testloader)
#    imgshow(testloader2)