{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "\n",
    "bs = 32\n",
    "n_workers = 4\n",
    "\n",
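    "# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 and std 0.5 then maps them to [-1, 1]\n",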
    "data_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "\n",
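    "# collect the loaders into an ordered dict (\"train\" / \"valid\") for the runner below\n",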
    "loaders = collections.OrderedDict()\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(\n",
    "    root='./data', train=True,\n",
    "    download=True, transform=data_transform)\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=bs,\n",
    "    shuffle=True, num_workers=n_workers)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(\n",
    "    root='./data', train=False,\n",
    "    download=True, transform=data_transform)\n",
    "testloader = torch.utils.data.DataLoader(\n",
    "    testset, batch_size=bs,\n",
    "    shuffle=False, num_workers=n_workers)\n",
    "\n",
    "loaders[\"train\"] = trainloader\n",
    "loaders[\"valid\"] = testloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
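    "# small LeNet-style CNN for 3x32x32 CIFAR10 images: two conv/pool blocks, three fully connected layers, 10 logits\n",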
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
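    "        # flatten the 16 x 5 x 5 feature maps before the fully connected layers\n",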
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model, criterion, optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
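    "# CrossEntropyLoss applies log-softmax internally, so the model returns raw logits;\n",
    "# Adam is used with its PyTorch defaults (lr=1e-3)\n",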
    "model = Net()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from common.dl.callbacks import (\n",
    "    ClassificationLossCallback, LoggerCallback, PrecisionCallback,\n",
    "    OptimizerCallback, CheckpointCallback, OneCycleLR)\n",
    "\n",
    "# the only tricky part\n",
    "n_epochs = 10\n",
    "logdir = \"./logs/sample\"\n",
    "\n",
    "callbacks = collections.OrderedDict()\n",
    "\n",
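    "# each callback handles one concern (judging by name and arguments): loss computation, optimizer step,\n",
    "# one-cycle LR/momentum schedule over n_epochs, precision at k=1/3/5, logging, and checkpoint saving\n",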
    "callbacks[\"loss\"] = ClassificationLossCallback()\n",
    "callbacks[\"optimizer\"] = OptimizerCallback()\n",
    "callbacks[\"one-cycle\"] = OneCycleLR(\n",
    "    cycle_len=n_epochs,\n",
    "    div=3, cut_div=4, momentum_range=(0.95, 0.85))\n",
    "callbacks[\"precision\"] = PrecisionCallback(\n",
    "    precision_args=[1, 3, 5])\n",
    "callbacks[\"logger\"] = LoggerCallback()\n",
    "callbacks[\"saver\"] = CheckpointCallback()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from common.dl.runner import ClassificationRunner\n",
    "\n",
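    "# the runner wraps model, criterion and optimizer; the callbacks above are passed to the training loop\n",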
    "runner = ClassificationRunner(\n",
    "    model=model,\n",
    "    criterion=criterion,\n",
    "    optimizer=optimizer)\n",
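    "# run training and validation over the \"train\" / \"valid\" loaders for n_epochs, writing to logdir\n",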
    "runner.train_stage(\n",
    "    loaders=loaders,\n",
    "    callbacks=callbacks,\n",
    "    logdir=logdir,\n",
    "    epochs=n_epochs, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py36]",
   "language": "python",
   "name": "conda-env-py36-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}