import torch
from torch import nn
from torch.optim import SGD
from torchnet.meter import AverageValueMeter
from torchnet.logger import VisdomPlotLogger, VisdomSaver
from sklearn.metrics import accuracy_score

def train_classifier(model, train_loader, test_loader, exp_name='experiment', lr=0.01, epochs=10, momentum=0.99):
    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr, momentum=momentum)
    # meters
    loss_meter = AverageValueMeter()
    acc_meter = AverageValueMeter()
    # plotters
    loss_logger = VisdomPlotLogger('line', env=exp_name, opts={'title': 'Loss', 'legend': ['train', 'test']})
    acc_logger = VisdomPlotLogger('line', env=exp_name, opts={'title': 'Accuracy', 'legend': ['train', 'test']})
    visdom_saver = VisdomSaver(envs=[exp_name])
    # device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # dictionary holding the training and test loaders
    loader = {
        'train': train_loader,
        'test': test_loader
    }
    for e in range(epochs):
        # alternate between the two modes: train and test
        for mode in ['train', 'test']:
            loss_meter.reset(); acc_meter.reset()
            model.train() if mode == 'train' else model.eval()
            with torch.set_grad_enabled(mode == 'train'):  # enable gradients only during training
                for i, batch in enumerate(loader[mode]):
                    x = batch[0].to(device)  # move the batch to the correct device
                    y = batch[1].to(device)
                    output = model(x)
                    l = criterion(output, y)
                    if mode == 'train':
                        l.backward()
                        optimizer.step()
                        optimizer.zero_grad()
                    acc = accuracy_score(y.to('cpu'), output.to('cpu').max(1)[1])
                    n = batch[0].shape[0]  # number of elements in the batch
                    loss_meter.add(l.item() * n, n)
                    acc_meter.add(acc * n, n)
                    if mode == 'train':
                        loss_logger.log(e + (i + 1) / len(loader[mode]), loss_meter.value()[0], name=mode)
                        acc_logger.log(e + (i + 1) / len(loader[mode]), acc_meter.value()[0], name=mode)
            # log the epoch-level averages for the current mode
            loss_logger.log(e + 1, loss_meter.value()[0], name=mode)
            acc_logger.log(e + 1, acc_meter.value()[0], name=mode)
        visdom_saver.save()  # persist the visdom environment at the end of each epoch
        torch.save(model.state_dict(), '%s-%d.pth' % (exp_name, e + 1))
    return model
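
# A minimal usage sketch for train_classifier. Everything below is illustrative and not
# part of the original paste: the dataset (MNIST via torchvision), the toy MLP, and the
# hyperparameter values are assumptions chosen only to show the call signature.
# A visdom server must already be running (python -m visdom.server) for the loggers to connect.

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# hypothetical placeholder model
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))

train_set = MNIST(root='data', train=True, download=True, transform=ToTensor())
test_set = MNIST(root='data', train=False, download=True, transform=ToTensor())

train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=64, num_workers=2)

model = train_classifier(model, train_loader, test_loader, exp_name='mnist_mlp', lr=0.01, epochs=10)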