Guest User

Untitled

a guest
Sep 18th, 2018
76
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.77 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Data"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": null,
  13. "metadata": {},
  14. "outputs": [],
  15. "source": [
  16. "import collections\n",
  17. "import torch\n",
  18. "import torchvision\n",
  19. "import torchvision.transforms as transforms\n",
  20. "\n",
  21. "\n",
  22. "bs = 32\n",
  23. "n_workers = 4\n",
  24. "\n",
  25. "data_transform = transforms.Compose([\n",
  26. " transforms.ToTensor(),\n",
  27. " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
  28. "\n",
  29. "\n",
  30. "loaders = collections.OrderedDict()\n",
  31. "\n",
  32. "trainset = torchvision.datasets.CIFAR10(\n",
  33. " root='./data', train=True,\n",
  34. " download=True, transform=data_transform)\n",
  35. "trainloader = torch.utils.data.DataLoader(\n",
  36. " trainset, batch_size=bs,\n",
  37. " shuffle=True, num_workers=n_workers)\n",
  38. "\n",
  39. "testset = torchvision.datasets.CIFAR10(\n",
  40. " root='./data', train=False,\n",
  41. " download=True, transform=data_transform)\n",
  42. "testloader = torch.utils.data.DataLoader(\n",
  43. " testset, batch_size=bs,\n",
  44. " shuffle=False, num_workers=n_workers)\n",
  45. "\n",
  46. "loaders[\"train\"] = trainloader\n",
  47. "loaders[\"valid\"] = testloader"
  48. ]
  49. },
  50. {
  51. "cell_type": "markdown",
  52. "metadata": {},
  53. "source": [
  54. "# Model definition"
  55. ]
  56. },
  57. {
  58. "cell_type": "code",
  59. "execution_count": null,
  60. "metadata": {},
  61. "outputs": [],
  62. "source": [
  63. "import collections\n",
  64. "import torch.nn as nn\n",
  65. "import torch.nn.functional as F\n",
  66. "\n",
  67. "class Net(nn.Module):\n",
  68. "    def __init__(self):\n",
  69. "        super(Net, self).__init__()\n",
  70. "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
  71. "        self.pool = nn.MaxPool2d(2, 2)\n",
  72. "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
  73. "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
  74. "        self.fc2 = nn.Linear(120, 84)\n",
  75. "        self.fc3 = nn.Linear(84, 10)\n",
  76. "\n",
  77. "    def forward(self, x):\n",
  78. "        x = self.pool(F.relu(self.conv1(x)))\n",
  79. "        x = self.pool(F.relu(self.conv2(x)))\n",
  80. "        x = x.view(-1, 16 * 5 * 5)\n",
  81. "        x = F.relu(self.fc1(x))\n",
  82. "        x = F.relu(self.fc2(x))\n",
  83. "        x = self.fc3(x)\n",
  84. "        return x"
  85. ]
  86. },
  87. {
  88. "cell_type": "markdown",
  89. "metadata": {},
  90. "source": [
  91. "# Model, criterion, optimizer"
  92. ]
  93. },
  94. {
  95. "cell_type": "code",
  96. "execution_count": null,
  97. "metadata": {},
  98. "outputs": [],
  99. "source": [
  100. "model = Net()\n",
  101. "criterion = nn.CrossEntropyLoss()\n",
  102. "optimizer = torch.optim.Adam(model.parameters())"
  103. ]
  104. },
  105. {
  106. "cell_type": "markdown",
  107. "metadata": {},
  108. "source": [
  109. "# Callbacks"
  110. ]
  111. },
  112. {
  113. "cell_type": "code",
  114. "execution_count": null,
  115. "metadata": {},
  116. "outputs": [],
  117. "source": [
  118. "from common.dl.callbacks import (\n",
  119. " ClassificationLossCallback, LoggerCallback, PrecisionCallback,\n",
  120. " OptimizerCallback, CheckpointCallback, OneCycleLR)\n",
  121. "\n",
  122. "# the only tricky part: configuring the callbacks (n_epochs is also used as the one-cycle length)\n",
  123. "n_epochs = 10\n",
  124. "logdir = \"./logs/sample\"\n",
  125. "\n",
  126. "callbacks = collections.OrderedDict()\n",
  127. "\n",
  128. "callbacks[\"loss\"] = ClassificationLossCallback()\n",
  129. "callbacks[\"optimizer\"] = OptimizerCallback()\n",
  130. "callbacks[\"one-cycle\"] = OneCycleLR(\n",
  131. " cycle_len=n_epochs,\n",
  132. " div=3, cut_div=4, momentum_range=(0.95, 0.85))\n",
  133. "callbacks[\"precision\"] = PrecisionCallback(\n",
  134. " precision_args=[1, 3, 5])\n",
  135. "callbacks[\"logger\"] = LoggerCallback()\n",
  136. "callbacks[\"saver\"] = CheckpointCallback()"
  137. ]
  138. },
  139. {
  140. "cell_type": "markdown",
  141. "metadata": {},
  142. "source": [
  143. "# Train"
  144. ]
  145. },
  146. {
  147. "cell_type": "code",
  148. "execution_count": null,
  149. "metadata": {},
  150. "outputs": [],
  151. "source": [
  152. "from common.dl.runner import ClassificationRunner\n",
  153. "\n",
  154. "runner = ClassificationRunner(\n",
  155. " model=model, \n",
  156. " criterion=criterion, \n",
  157. " optimizer=optimizer)\n",
  158. "runner.train_stage(\n",
  159. " loaders=loaders, \n",
  160. " callbacks=callbacks, \n",
  161. " logdir=logdir,\n",
  162. " epochs=n_epochs, verbose=True)"
  163. ]
  164. },
  165. {
  166. "cell_type": "code",
  167. "execution_count": null,
  168. "metadata": {},
  169. "outputs": [],
  170. "source": []
  171. }
  172. ],
  173. "metadata": {
  174. "kernelspec": {
  175. "display_name": "Python [conda env:py36]",
  176. "language": "python",
  177. "name": "conda-env-py36-py"
  178. },
  179. "language_info": {
  180. "codemirror_mode": {
  181. "name": "ipython",
  182. "version": 3
  183. },
  184. "file_extension": ".py",
  185. "mimetype": "text/x-python",
  186. "name": "python",
  187. "nbconvert_exporter": "python",
  188. "pygments_lexer": "ipython3",
  189. "version": "3.6.4"
  190. }
  191. },
  192. "nbformat": 4,
  193. "nbformat_minor": 2
  194. }
Add Comment
Please, Sign In to add comment