daily pastebin goal
25%
SHARE
TWEET

Untitled

a guest Dec 9th, 2018 62 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
import os
# Legacy Colab recipe for picking a matching torch wheel.
# NOTE(review): wheel.pep425tags was removed in wheel >= 0.35 — this import
# fails on modern environments; packaging.tags is the replacement. Confirm
# whether this bootstrap block is still needed at all.
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
# e.g. 'cp36-cp36m'. NOTE: shadows the stdlib `platform` module name.
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())

# Choose the CUDA vs CPU wheel flavor by probing for nvidia-smi (Colab path).
accelerator = 'cu80' if os.path.exists('/opt/bin/nvidia-smi') else 'cpu'

import torch
print('Version', torch.__version__)
print('CUDA enabled:', torch.cuda.is_available())


# Part 1: Upload the Dataset

BASE_PATH = "/tmp/490/"                        # scratch root for the assignment
DATA_PATH = BASE_PATH + 'cartoon_network/'     # h5 datasets + checkpoints live here

os.chdir(BASE_PATH)

# Duplicate `torch`/`os` imports are harmless and kept from the notebook export.
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
import numpy as np
import os
import torch.nn.functional as F
import torch.optim as optim
import h5py
import sys
sys.path.append(BASE_PATH)  # makes the course-provided pt_util module importable
import pt_util
  31.  
  32. # Part 2: Defining the Network
  33.  
  34. class TinyImagenetNet(nn.Module):
  35.     def __init__(self):
  36.         super(TinyImagenetNet, self).__init__()
  37.         # TODO define the layers\n",
  38.  
  39.         # I think there are 3 input channels for 3 colors\n",
  40.         # Lets try doing a combination of resnet nad darknet\n",
  41.  
  42.         # TODO define the layers\n",
  43.  
  44.         # I think there are 3 input channels for 3 colors\n",
  45.         # Lets try doing a combination of resnet nad darknet\n",
  46.         # now just using a sample of darknet\n",
  47.         self.conv4 = nn.Conv2d(3, 128, kernel_size=3, stride=1)
  48.         self.conv5 = nn.Conv2d(128, 64, kernel_size=1, stride=1)
  49.         self.conv6 = nn.Conv2d(64, 128, kernel_size=3, stride=1)
  50.         self.pool7 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
  51.         self.conv8 = nn.Conv2d(128, 256, kernel_size=3, stride=1)
  52.         self.conv9 = nn.Conv2d(256, 128, kernel_size=1, stride=1)
  53.         self.conv10 = nn.Conv2d(128, 256, kernel_size=3, stride=1)
  54.         self.pool11 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
  55.         self.conv16 = nn.Conv2d(256, 512, kernel_size=3, stride=1)
  56.         self.pool17 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
  57.         self.conv18 = nn.Conv2d(512, 1024, kernel_size=3, padding=0)
  58.         self.conv19 = nn.Conv2d(1024, 2048, kernel_size=3, padding=0)
  59.         self.pool20 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
  60.         self.avg4b4 = nn.AvgPool2d(kernel_size=5, stride=4, padding=0)
  61.         self.fc = nn.Linear(2048*2*2, 4)
  62.  
  63.         # use for save_best_model\n",
  64.         self.best = 0.0
  65.  
  66.     def forward(self, x):  # TODO define the forward pass\n",
  67.         # raise NotImplementedError('Need to define the forward pass')\n",
  68.         # dimms with batch size 10\n",
  69.         x = x  # [10, 3, 256, 256]\n",
  70.         x = F.relu(self.conv4(x))  # [10, 128, 254, 254]\n",
  71.         x = F.relu(self.conv5(x))  # [10, 64, 254, 254]\n",
  72.         x = F.relu(self.conv6(x))  # [10, 128, 252, 252]\n",
  73.         x = self.pool7(x)  # [10, 128, 126, 126]\n",
  74.         x = F.relu(self.conv8(x))  # [10, 256, 124, 124]\n",
  75.         x = F.relu(self.conv9(x))  # [10, 128, 124, 124]\n",
  76.         x = F.relu(self.conv10(x))  # [10, 256, 122, 122]\n",
  77.         x = self.pool11(x)  # [10, 256, 61, 61]\n",
  78.         x = F.relu(self.conv16(x))  # [10, 512, 59, 59]\n",
  79.         x = self.pool17(x)  # [10, 512, 29, 29]\n",
  80.         x = F.relu(self.conv18(x)) # [10, 1024, 29, 29]
  81.         x = F.relu(self.conv19(x)) # [10, 2048, 25, 25]
  82.         #print("yee{}".format(x.shape))
  83.         x= self.pool20(x) # [10, 2048, 12, 12]
  84.         #print("fee{}".format(x.shape))
  85.         x = self.avg4b4(x)  # [10, 2048, 2, 2]\n",
  86.         #print("zee{}".format(x.shape))
  87.         x = x.view(-1, 2048*2*2)
  88.  
  89.         x = self.fc(x)
  90.         return F.relu(x)
  91.  
  92.     def loss(self, prediction, label, reduction='elementwise_mean'):
  93.         _, indecies = torch.max(label, 1)
  94.         loss_val = F.cross_entropy(prediction, indecies, reduction=reduction)
  95.         return loss_val
  96.  
  97.     def save_model(self, file_path, num_to_keep=1):
  98.         pt_util.save(self, file_path, num_to_keep)
  99.  
  100.     def save_best_model(self, accuracy, file_path, num_to_keep=1):
  101.         # TODO save the model if it is the best
  102.         if (accuracy > self.best):
  103.             self.best = accuracy;
  104.             self.save_model(file_path, num_to_keep)
  105.         else:
  106.             print('Old model was better new: ' + str(accuracy) + 'old: ' + str(self.best))
  107.         # raise NotImplementedError('Need to implement save_best_model')
  108.  
  109.     def load_model(self, file_path):
  110.         pt_util.restore(self, file_path)
  111.  
  112.     def load_last_model(self, dir_path):
  113.         return pt_util.restore_latest(self, dir_path)
  114.  
  115.  
  116. import time
  117. def train(model, device, train_loader, optimizer, epoch, log_interval):
  118.     model.train()
  119.     for batch_idx, (data, label) in enumerate(train_loader):
  120.         data, label = data.to(device), label.to(device)
  121.         optimizer.zero_grad()
  122.         output = model(data)
  123.         loss = model.loss(output, label)
  124.         loss.backward()
  125.         optimizer.step()
  126.         if batch_idx % log_interval == 0:
  127.             print('{} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  128.                 time.ctime(time.time()),
  129.                 epoch, batch_idx * len(data), len(train_loader.dataset),
  130.                 100. * batch_idx / len(train_loader), loss.item()))
  131.  
def test(model, device, test_loader, return_images=False, log_interval=None):
    """Evaluate `model` on `test_loader` and print average loss/accuracy.

    When `return_images` is True, returns (correct_images, correct_values,
    error_images, predicted_values, gt_values, test_accuracy) as numpy
    arrays; otherwise returns None (only prints).
    """
    model.eval()
    test_loss = 0
    correct = 0

    correct_images = []
    correct_values = []

    error_images = []
    predicted_values = []
    gt_values = []
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(test_loader):
            data, label = data.to(device), label.to(device)
            output = model(data)
            # Sum-reduced loss so the per-dataset average below is exact
            # even with a ragged final batch.
            test_loss_on = model.loss(output, label, reduction='sum').item()
            test_loss += test_loss_on
            # Both sides are argmaxed along dim 1: assumes labels are
            # one-hot rows of shape [batch, num_classes] — TODO(review)
            # confirm against the H5 dataset's label format.
            pred = output.max(1)[1]
            lbl = label.max(1)[1]
            correct_mask = pred.eq(lbl)
            num_correct = correct_mask.sum().item()
            correct += num_correct
            if return_images:
                if num_correct > 0:
                    correct_images.append(data[correct_mask, ...].data.cpu().numpy())
                    # NOTE(review): takes column 0 of the label rows; only
                    # meaningful if that column encodes the class id — verify.
                    correct_value_data = label[correct_mask].data.cpu().numpy()[:, 0]
                    correct_values.append(correct_value_data)
                if num_correct < len(label):
                    error_data = data[~correct_mask, ...].data.cpu().numpy()
                    error_images.append(error_data)
                    predicted_value_data = pred[~correct_mask].data.cpu().numpy()
                    predicted_values.append(predicted_value_data)
                    gt_value_data = label[~correct_mask].data.cpu().numpy()[:, 0]
                    gt_values.append(gt_value_data)
            if log_interval is not None and batch_idx % log_interval == 0:
                print('{} Test: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    time.ctime(time.time()),
                    batch_idx * len(data), len(test_loader.dataset),
                    100. * batch_idx / len(test_loader), test_loss_on))
    if return_images:
        # NOTE(review): np.concatenate raises ValueError on an empty list —
        # this breaks if the whole run is all-correct or all-wrong.
        correct_images = np.concatenate(correct_images, axis=0)
        error_images = np.concatenate(error_images, axis=0)
        predicted_values = np.concatenate(predicted_values, axis=0)
        correct_values = np.concatenate(correct_values, axis=0)
        gt_values = np.concatenate(gt_values, axis=0)

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), test_accuracy))
    if return_images:
        return correct_images, correct_values, error_images, predicted_values, gt_values, test_accuracy
  185.  
  186. # Part 3: Loading Data
  187.  
  188. # Data loader
  189. class H5Dataset(torch.utils.data.Dataset):
  190.     def __init__(self, h5_file, transform=None):
  191.         # Implement data loading.
  192.         f = h5py.File(h5_file, 'r')
  193.         self.yee = f.filename
  194.         self.images = f['images'].value
  195.         self.labels = f['labels'].value
  196.  
  197.         # self.labels = np.array(self.hfile['labels'])
  198.         self.transform = transform
  199.         # raise NotImplementedError('Need to implement the data loading')
  200.         f.close()
  201.  
  202.     def __len__(self):
  203.         # Implement the length function
  204.         return self.images.shape[0]
  205.  
  206.     def __getitem__(self, idx):
  207.         # implement the getitem function
  208.         # You should return a tuple of:
  209.         #    a torch tensor containing single image in CxHxW format and
  210.         #    the label as a single tensor scalar
  211.         # raise NotImplementedError('Need to implement the data loading')
  212.  
  213.         # best practice f = h5py.File() and then f["image"][idx] in here
  214.  
  215.         # w and h may be swapped, but channel is first
  216.         data = torch.FloatTensor(self.images[idx]).permute(2, 0, 1)
  217.         label = torch.LongTensor(self.labels[idx])
  218.         # if (self.yee.__contains__("train")):
  219.         #     print("fetching{}".format(idx))
  220.         #     print("data{}".format(data))
  221.         #     print("label{}".format(label))
  222.         if self.transform:
  223.             data = self.transform(data)
  224.         return (data, label)
  225.  
# Part 4: dataset augmentation

import torchvision

# Training-time augmentation. ToPILImage expects a CHW tensor (what
# H5Dataset yields); ToTensor converts back to CHW float in [0, 1].
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.0),
    # NOTE(review): ratio=(0.75, 0.75) pins a fixed 3:4 aspect before the
    # resize back to 256x256 — confirm the squeeze is intentional.
    transforms.RandomResizedCrop(256, scale=(0.4, 0.8), ratio=(0.75, 0.75)),
    # transforms.Resize((10,10)),
    transforms.ToTensor()
])

# Deterministic test-time pipeline. CenterCrop(256) is a no-op on 256x256
# inputs but guards against larger images.
test_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.CenterCrop(256),
    transforms.ToTensor()
])


# test_transforms = None

data_train = H5Dataset(DATA_PATH + 'cartoon_train.h5', transform=train_transforms)
data_test = H5Dataset(DATA_PATH + 'cartoon_test.h5', transform=test_transforms)
  250.  
  251. # Part 5: Training the network
  252.  
# Play around with these constants, you may find a better setting.
BATCH_SIZE = 10
# BATCH_SIZE = 1024 this caused an error (presumably OOM at 256x256 — TODO confirm)
TEST_BATCH_SIZE = 10
EPOCHS = 10
# LEARNING_RATE = 0.01
LEARNING_RATE = 0.02
MOMENTUM = 0.9
USE_CUDA = True
PRINT_INTERVAL = 100  # batches between training log lines
WEIGHT_DECAY = 0.0005
LOG_PATH = DATA_PATH + 'log.pkl'

# Now the actual training code
use_cuda = USE_CUDA and torch.cuda.is_available()

device = torch.device("cuda" if use_cuda else "cpu")
print('Using device', device)
import multiprocessing

print('num cpus:', multiprocessing.cpu_count())

# DataLoader performance knobs: one worker per CPU and pinned host memory,
# both only useful when CUDA is active.
kwargs = {'num_workers': multiprocessing.cpu_count(),
          'pin_memory': True} if use_cuda else {}
  277.  
  278. # class_names = [line.strip().split(', ') for line in open(DATA_PATH + 'class_names.txt')]
  279. class_names = ["one", "two", "three", "four"]
  280. name_to_class = {line[1]: line[0] for line in class_names}
  281. class_names = [line[1] for line in class_names]
  282.  
train_loader = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE,
                                           shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(data_test, batch_size=TEST_BATCH_SIZE,
                                          shuffle=False, **kwargs)

model = TinyImagenetNet().to(device)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# Resume from the newest checkpoint in the directory, if any; the returned
# value is used below as the epoch to start training from.
start_epoch = model.load_last_model(DATA_PATH + 'checkpoints')

# You may want to define another default for your log data depending on how you save it.
log_data = pt_util.read_log(LOG_PATH, [])

# Baseline evaluation (and sample images) before any further training.
correct_images, correct_val, error_images, predicted_val, gt_val, test_accuracy = test(model, device, test_loader, True, 10)
# Convert NCHW float batches to NHWC uint8 for display.
correct_images = pt_util.to_scaled_uint8(correct_images.transpose(0, 2, 3, 1))
error_images = pt_util.to_scaled_uint8(error_images.transpose(0, 2, 3, 1))
pt_util.show_images(correct_images, ['correct: %s' % class_names[aa] for aa in correct_val])
pt_util.show_images(error_images, ['pred: %s, actual: %s' % (class_names[aa], class_names[bb]) for aa, bb in
                                   zip(predicted_val, gt_val)])
  301.  
  302. try:
  303.     for epoch in range(start_epoch, EPOCHS + 1):
  304.         train(model, device, train_loader, optimizer, epoch, PRINT_INTERVAL)
  305.         correct_images, correct_val, error_images, predicted_val, gt_val, test_accuracy = test(model, device,
  306.                                                                                                test_loader, True, 10)
  307.         # TODO define other things to do at the end of each loop like logging and saving the best model.
  308.         model.save_best_model(test_accuracy, DATA_PATH + 'checkpoints/best.pt')
  309.  
  310.  
  311. except KeyboardInterrupt as ke:
  312.     print('Interrupted')
  313. except:
  314.     import traceback
  315.  
  316.     traceback.print_exc()
  317. finally:
  318.     # Always save the most recent model, but don't delete any existing ones.
  319.     model.save_model(DATA_PATH + 'checkpoints/%03d.pt' % epoch, 0)
  320.  
  321.     # Show some current correct/incorrect images.
  322.     correct_images = pt_util.to_scaled_uint8(correct_images.transpose(0, 2, 3, 1))
  323.     error_images = pt_util.to_scaled_uint8(error_images.transpose(0, 2, 3, 1))
  324.     pt_util.show_images(correct_images, ['correct: %s' % class_names[aa] for aa in correct_val])
  325.     pt_util.show_images(error_images, ['pred: %s, actual: %s' % (class_names[aa], class_names[bb]) for aa, bb in
  326.                                        zip(predicted_val, gt_val)])
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top