Guest User

Untitled

a guest
Jan 16th, 2019
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.25 KB | None | 0 0
  1. import sys
  2. import collections
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. import torchvision
  7. import torchvision.transforms as transforms
  8.  
  9. import catalyst
  10. from catalyst.dl.callbacks import (
  11. ClassificationLossCallback,
  12. Logger, TensorboardLogger,
  13. OptimizerCallback, SchedulerCallback, CheckpointCallback,
  14. PrecisionCallback, OneCycleLR)
  15.  
  16. from catalyst.dl.runner import ClassificationRunner
  17.  
  18.  
  19. class Net(nn.Module):
  20. def __init__(self):
  21. super(Net, self).__init__()
  22. self.conv1 = nn.Conv2d(3, 6, 5)
  23. self.pool = nn.MaxPool2d(2, 2)
  24. self.conv2 = nn.Conv2d(6, 16, 5)
  25. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  26. self.fc2 = nn.Linear(120, 84)
  27. self.fc3 = nn.Linear(84, 10)
  28.  
  29. def forward(self, x):
  30. x = self.pool(F.relu(self.conv1(x)))
  31. x = self.pool(F.relu(self.conv2(x)))
  32. x = x.view(-1, 16 * 5 * 5)
  33. x = F.relu(self.fc1(x))
  34. x = F.relu(self.fc2(x))
  35. x = self.fc3(x)
  36. return x
  37.  
  38.  
  39. def main():
  40. print('Python version', sys.version)
  41. print('Catalyst version:', catalyst.__version__)
  42.  
  43. bs = 32
  44. n_workers = 0
  45.  
  46. data_transform = transforms.Compose([
  47. transforms.ToTensor(),
  48. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  49.  
  50. loaders = collections.OrderedDict()
  51.  
  52. trainset = torchvision.datasets.CIFAR10(
  53. root='./data', train=True,
  54. download=True, transform=data_transform)
  55. trainloader = torch.utils.data.DataLoader(
  56. trainset, batch_size=bs,
  57. shuffle=True, num_workers=n_workers)
  58.  
  59. testset = torchvision.datasets.CIFAR10(
  60. root='./data', train=False,
  61. download=True, transform=data_transform)
  62. testloader = torch.utils.data.DataLoader(
  63. testset, batch_size=bs,
  64. shuffle=False, num_workers=n_workers)
  65.  
  66. loaders["train"] = trainloader
  67. loaders["valid"] = testloader
  68.  
  69. model = Net().cuda()
  70. criterion = nn.CrossEntropyLoss()
  71. optimizer = torch.optim.Adam(model.parameters())
  72. # scheduler = None # for OneCycle usage
  73. # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 8], gamma=0.3)
  74. scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)
  75.  
  76. # the only tricky part
  77. n_epochs = 10
  78. logdir = "./logs/cifar_simple_notebook"
  79.  
  80. callbacks = collections.OrderedDict()
  81.  
  82. callbacks["loss"] = ClassificationLossCallback()
  83. callbacks["optimizer"] = OptimizerCallback()
  84. callbacks["precision"] = PrecisionCallback(
  85. precision_args=[1, 3, 5])
  86.  
  87. # OneCylce custom scheduler callback
  88. # callbacks["scheduler"] = OneCycleLR(
  89. # cycle_len=n_epochs,
  90. # div=3, cut_div=4, momentum_range=(0.95, 0.85))
  91.  
  92. # Pytorch scheduler callback
  93. callbacks["scheduler"] = SchedulerCallback(
  94. reduce_metric="precision01")
  95.  
  96. callbacks["saver"] = CheckpointCallback()
  97. callbacks["logger"] = Logger()
  98. callbacks["tflogger"] = TensorboardLogger()
  99.  
  100. runner = ClassificationRunner(
  101. model=model,
  102. criterion=criterion,
  103. optimizer=optimizer,
  104. scheduler=scheduler)
  105. runner.train(
  106. loaders=loaders,
  107. callbacks=callbacks,
  108. logdir=logdir,
  109. epochs=n_epochs, verbose=True)
  110.  
  111.  
# Run the training entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Add Comment
Please, Sign In to add comment