Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # -*- coding: utf-8 -*-
- """
- Options
- """
- import argparse
- import os
- import torch
class Options():
    """Command-line options shared by training and testing.

    ``__init__`` builds the ``argparse`` parser; :meth:`parse` reads the
    arguments, creates the output directories and records the options in
    ``<outf>/<name>/train/opt.txt``.

    Returns:
        [argparse]: argparse containing train and test options
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        # Base
        self.parser.add_argument('--dataset', default='folder', help='folder | cifar10 | mnist ')
        self.parser.add_argument('--dataroot', default='', help='path to dataset')
        self.parser.add_argument('--path', default='', help='path to the folder or image to be predicted.')
        self.parser.add_argument('--batchsize', type=int, default=8, help='input batch size')
        self.parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
        self.parser.add_argument('--droplast', action='store_true', default=True, help='Drop last batch size.')
        # --droplast defaults to True, so on its own it can never be switched
        # off; --no_droplast stores False into the same destination.
        self.parser.add_argument('--no_droplast', dest='droplast', action='store_false',
                                 help='Keep the last incomplete batch (sets droplast to False).')
        self.parser.add_argument('--isize', type=int, default=128, help='input image size.')
        self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment')
        # parse() reads opt.model and opt.outf, so both must exist on the parser.
        self.parser.add_argument('--model', type=str, default='model',
                                 help='model name; used to build the default experiment name')
        self.parser.add_argument('--outf', default='./output', help='folder to output experiment results')

    def parse(self, args=None):
        """Parse arguments, create the output directories and save the options.

        Args:
            args (list[str] | None): arguments to parse. ``None`` (the default)
                parses ``sys.argv[1:]``, exactly as before.

        Returns:
            argparse.Namespace: the parsed options (also stored as ``self.opt``).
        """
        self.opt = self.parser.parse_args(args)

        # Without an explicit --name, derive "<model>/<dataset>".
        if self.opt.name == 'experiment_name':
            self.opt.name = "%s/%s" % (self.opt.model, self.opt.dataset)

        expr_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        test_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        os.makedirs(expr_dir, exist_ok=True)
        os.makedirs(test_dir, exist_ok=True)

        # save to the disk
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt', encoding='utf-8') as opt_file:
            opt_file.write('------------ Options -------------\n')
            for k, v in sorted(vars(self.opt).items()):
                opt_file.write('%s: %s\n' % (str(k), str(v)))
            opt_file.write('-------------- End ----------------\n')
        return self.opt
- # -*- coding: utf-8 -*-
- """
- Dataloader
- """
- import os
- from torchvision import transforms
- from torch.utils.data import DataLoader
- from torchvision.datasets import ImageFolder
class Data:
    """Container holding a training loader and a validation loader.

    Args:
        train: object exposed unchanged as ``self.train``.
        valid: object exposed unchanged as ``self.valid``.
    """

    def __init__(self, train, valid):
        self.train, self.valid = train, valid
def load_data(opt):
    """Build train and validation dataloaders from an ImageFolder layout.

    Images are read from ``<opt.dataroot>/train`` and ``<opt.dataroot>/test``
    with torchvision ``ImageFolder`` (one sub-directory per class). Each image
    is resized and centre-cropped to ``opt.isize``, converted to a tensor and
    normalised with mean 0.5 / std 0.5 on three channels.

    Args:
        opt: parsed options. Reads ``dataroot``, ``isize`` and ``batchsize``;
            also honours ``workers`` and ``droplast`` when present.

    Returns:
        Data: ``train`` is a shuffled loader; ``valid`` is an unshuffled loader
        that keeps the final partial batch.
    """
    transform = transforms.Compose([transforms.Resize(opt.isize),
                                    transforms.CenterCrop(opt.isize),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])
    train_ds = ImageFolder(os.path.join(opt.dataroot, 'train'), transform)
    valid_ds = ImageFolder(os.path.join(opt.dataroot, 'test'), transform)

    # Honour the --workers / --droplast options. The fallbacks reproduce the
    # former hard-coded behaviour (main-process loading, drop_last=True) for
    # option objects that do not define these fields.
    num_workers = getattr(opt, 'workers', 0)
    drop_last = getattr(opt, 'droplast', True)

    # DATALOADER
    train_dl = DataLoader(dataset=train_ds, batch_size=opt.batchsize, shuffle=True,
                          drop_last=drop_last, num_workers=num_workers)
    # Validation never drops the last partial batch so every sample is evaluated.
    valid_dl = DataLoader(dataset=valid_ds, batch_size=opt.batchsize, shuffle=False,
                          drop_last=False, num_workers=num_workers)
    return Data(train_dl, valid_dl)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement