Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import collections\n",
- "\n",
- "import torch\n",
- "import torchvision\n",
- "import torchvision.transforms as transforms\n",
- "\n",
- "# Seed the global torch RNG so the training shuffle order and the model's\n",
- "# weight initialisation (created in a later cell) are reproducible on Run All.\n",
- "seed = 42\n",
- "torch.manual_seed(seed)\n",
- "\n",
- "bs = 32\n",
- "n_workers = 4\n",
- "\n",
- "# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 per\n",
- "# channel then maps them to [-1, 1].\n",
- "data_transform = transforms.Compose([\n",
- "    transforms.ToTensor(),\n",
- "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
- "\n",
- "loaders = collections.OrderedDict()\n",
- "\n",
- "trainset = torchvision.datasets.CIFAR10(\n",
- "    root='./data', train=True,\n",
- "    download=True, transform=data_transform)\n",
- "trainloader = torch.utils.data.DataLoader(\n",
- "    trainset, batch_size=bs,\n",
- "    shuffle=True, num_workers=n_workers)\n",
- "\n",
- "testset = torchvision.datasets.CIFAR10(\n",
- "    root='./data', train=False,\n",
- "    download=True, transform=data_transform)\n",
- "testloader = torch.utils.data.DataLoader(\n",
- "    testset, batch_size=bs,\n",
- "    shuffle=False, num_workers=n_workers)\n",
- "\n",
- "loaders[\"train\"] = trainloader\n",
- "# NOTE: the official CIFAR-10 test split doubles as the validation set here.\n",
- "loaders[\"valid\"] = testloader"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import collections\n",
- "import torch.nn as nn\n",
- "import torch.nn.functional as F\n",
- "\n",
- "\n",
- "class Net(nn.Module):\n",
- "    \"\"\"Small LeNet-style CNN classifier.\n",
- "\n",
- "    Two 5x5 conv + 2x2 max-pool stages followed by three fully connected\n",
- "    layers. The flattened size 16 * 5 * 5 assumes 3x32x32 inputs\n",
- "    (32 -> conv 28 -> pool 14 -> conv 10 -> pool 5), e.g. CIFAR-10 images.\n",
- "\n",
- "    Args:\n",
- "        num_classes: number of output logits (default 10, as for CIFAR-10).\n",
- "    \"\"\"\n",
- "\n",
- "    def __init__(self, num_classes=10):\n",
- "        super().__init__()\n",
- "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
- "        self.pool = nn.MaxPool2d(2, 2)\n",
- "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
- "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
- "        self.fc2 = nn.Linear(120, 84)\n",
- "        self.fc3 = nn.Linear(84, num_classes)\n",
- "\n",
- "    def forward(self, x):\n",
- "        \"\"\"Return unnormalized class logits of shape (batch, num_classes).\"\"\"\n",
- "        x = self.pool(F.relu(self.conv1(x)))\n",
- "        x = self.pool(F.relu(self.conv2(x)))\n",
- "        # Flatten per sample. Keeping the batch dimension explicit makes a\n",
- "        # wrongly sized input fail in fc1 instead of being silently folded\n",
- "        # into a different batch size, as view(-1, 16 * 5 * 5) would do.\n",
- "        x = x.view(x.size(0), -1)\n",
- "        x = F.relu(self.fc1(x))\n",
- "        x = F.relu(self.fc2(x))\n",
- "        return self.fc3(x)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Model, criterion, optimizer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Build the three objects the runner consumes. Adam keeps its library-default\n",
- "# hyperparameters here; the OneCycleLR callback presumably adjusts the\n",
- "# learning rate / momentum during training - TODO confirm.\n",
- "model = Net()\n",
- "optimizer = torch.optim.Adam(params=model.parameters())\n",
- "criterion = nn.CrossEntropyLoss()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Callbacks"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import collections\n",
- "\n",
- "from common.dl.callbacks import (\n",
- "    ClassificationLossCallback, LoggerCallback, PrecisionCallback,\n",
- "    OptimizerCallback, CheckpointCallback, OneCycleLR)\n",
- "\n",
- "# Run configuration. n_epochs feeds both the one-cycle schedule below and the\n",
- "# epochs argument of the training call, keeping the two in sync.\n",
- "n_epochs = 10\n",
- "logdir = \"./logs/sample\"\n",
- "\n",
- "# OrderedDict preserves insertion order; presumably the runner invokes the\n",
- "# callbacks in this order (loss before optimizer step) - TODO confirm.\n",
- "callbacks = collections.OrderedDict()\n",
- "\n",
- "callbacks[\"loss\"] = ClassificationLossCallback()\n",
- "callbacks[\"optimizer\"] = OptimizerCallback()\n",
- "callbacks[\"one-cycle\"] = OneCycleLR(\n",
- "    cycle_len=n_epochs,\n",
- "    div=3, cut_div=4, momentum_range=(0.95, 0.85))\n",
- "# presumably top-k precision for k = 1, 3, 5 - verify against PrecisionCallback\n",
- "callbacks[\"precision\"] = PrecisionCallback(\n",
- "    precision_args=[1, 3, 5])\n",
- "callbacks[\"logger\"] = LoggerCallback()\n",
- "callbacks[\"saver\"] = CheckpointCallback()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Train"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from common.dl.runner import ClassificationRunner\n",
- "\n",
- "# Depends on earlier cells: model / criterion / optimizer, loaders,\n",
- "# callbacks, n_epochs and logdir must already be defined.\n",
- "runner = ClassificationRunner(\n",
- "    model=model,\n",
- "    criterion=criterion,\n",
- "    optimizer=optimizer)\n",
- "runner.train_stage(\n",
- "    loaders=loaders,\n",
- "    callbacks=callbacks,\n",
- "    logdir=logdir,\n",
- "    epochs=n_epochs, verbose=True)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python [conda env:py36]",
- "language": "python",
- "name": "conda-env-py36-py"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.4"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
Add Comment
Please, Sign In to add comment