Advertisement
Guest User

Untitled

a guest
Apr 8th, 2020
197
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.35 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. """
  3. Options
  4. """
  5.  
  6. import argparse
  7. import os
  8. import torch
  9.  
  10.  
  11. class Options():
  12.     """Options class
  13.  
  14.    Returns:
  15.        [argparse]: argparse containing train and test options
  16.    """
  17.  
  18.     def __init__(self):
  19.  
  20.         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  21.  
  22.         # Base
  23.         self.parser.add_argument('--dataset', default='folder', help='folder | cifar10 | mnist ')
  24.         self.parser.add_argument('--dataroot', default='', help='path to dataset')        
  25.         self.parser.add_argument('--path', default='', help='path to the folder or image to be predicted.')
  26.         self.parser.add_argument('--batchsize', type=int, default=8, help='input batch size')
  27.         self.parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
  28.         self.parser.add_argument('--droplast', action='store_true', default=True, help='Drop last batch size.')
  29.         self.parser.add_argument('--isize', type=int, default=128, help='input image size.')
  30.         self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment')
  31.  
  32.     def parse(self):
  33.         """
  34.        Parse Arguments.
  35.        """
  36.  
  37.         self.opt = self.parser.parse_args()
  38.  
  39.  
  40.         args = vars(self.opt)
  41.  
  42.  
  43.         # save to the disk
  44.         if self.opt.name == 'experiment_name':
  45.             self.opt.name = "%s/%s" % (self.opt.model, self.opt.dataset)
  46.         expr_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
  47.         test_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
  48.  
  49.         if not os.path.isdir(expr_dir):
  50.             os.makedirs(expr_dir)
  51.         if not os.path.isdir(test_dir):
  52.             os.makedirs(test_dir)
  53.  
  54.         file_name = os.path.join(expr_dir, 'opt.txt')
  55.         with open(file_name, 'wt') as opt_file:
  56.             opt_file.write('------------ Options -------------\n')
  57.             for k, v in sorted(args.items()):
  58.                 opt_file.write('%s: %s\n' % (str(k), str(v)))
  59.             opt_file.write('-------------- End ----------------\n')
  60.         return self.opt
  61.  
  62.  
  63. # -*- coding: utf-8 -*-
  64. """
  65. Dataloader
  66. """
  67.  
  68. import os
  69. from torchvision import transforms
  70. from torch.utils.data import DataLoader
  71. from torchvision.datasets import ImageFolder
  72.  
  73.  
  74. class Data:
  75.     """
  76.    Dataloader containing train and validation sets.
  77.    """
  78.     def __init__(self, train, valid):
  79.         self.train = train
  80.         self.valid = valid
  81.  
  82. ##
  83. def load_data(opt):
  84.     """ Load Data
  85.  
  86.    Args:
  87.        opt ([type]): Argument Parser
  88.  
  89.  
  90.    Returns:
  91.        [type]: dataloader
  92.    """
  93.  
  94.  
  95.     transform = transforms.Compose([transforms.Resize(opt.isize),
  96.                                         transforms.CenterCrop(opt.isize),
  97.                                         transforms.ToTensor(),
  98.                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])
  99.  
  100.     train_ds = ImageFolder(os.path.join(opt.dataroot, 'train'), transform)
  101.     valid_ds = ImageFolder(os.path.join(opt.dataroot, 'test'), transform)
  102.  
  103.     ## DATALOADER
  104.     train_dl = DataLoader(dataset=train_ds, batch_size=opt.batchsize, shuffle=True, drop_last=True)
  105.     valid_dl = DataLoader(dataset=valid_ds, batch_size=opt.batchsize, shuffle=False, drop_last=False)
  106.  
  107.     return Data(train_dl, valid_dl)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement