Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/tao/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
- " from ._conv import register_converters as _register_converters\n"
- ]
- }
- ],
- "source": [
- "import os\n",
- "import math\n",
- "import torch\n",
- "import torch.nn.functional as F\n",
- "import torch.optim as optim\n",
- "import torch.utils.data as data\n",
- "import numpy as np\n",
- "import h5py\n",
- "import cv2\n",
- "from torch.autograd import Variable\n",
- "from torchvision import datasets, transforms\n",
- "from glob import glob\n",
- "import argparse\n",
- "from model import CVAE\n",
- "\n",
- "\n",
- "class MpiSintel(data.Dataset):\n",
- "    \"\"\"MPI-Sintel optical-flow dataset built from ``.mat`` flow files.\n",
- "\n",
- "    Pairs every ``<root>/flow/<scene>/<prefix>NNNN.mat`` file with the frame\n",
- "    ``<root>/<dstype>/<scene>/<prefix>NNNN.png`` (NNNN = 4-digit frame number);\n",
- "    flow files without a matching frame are skipped.\n",
- "\n",
- "    Each item is ``(flow, ce, cd)``: the flow resized to 224x224 and the frame\n",
- "    resized to 224x224 and to 28x28, each passed through ``transforms`` if given.\n",
- "    \"\"\"\n",
- "\n",
- "    def __init__(self, root='', dstype='final', transforms=None):\n",
- "        flow_root = os.path.join(root, 'flow')\n",
- "        image_root = os.path.join(root, dstype)\n",
- "\n",
- "        self.flow_list = []\n",
- "        self.image_list = []\n",
- "        self.transforms = transforms\n",
- "\n",
- "        for flow_path in sorted(glob(os.path.join(flow_root, '*/*.mat'))):\n",
- "            # fbase is '<scene>/<prefix>NNNN.mat'; split off the 4-digit frame number.\n",
- "            fbase = flow_path[len(flow_root)+1:]\n",
- "            fprefix = fbase[:-8]\n",
- "            fnum = int(fbase[-8:-4])\n",
- "\n",
- "            img = os.path.join(image_root, fprefix + '%04d' % fnum + '.png')\n",
- "\n",
- "            if not os.path.isfile(img) or not os.path.isfile(flow_path):\n",
- "                continue\n",
- "\n",
- "            self.image_list.append(img)\n",
- "            self.flow_list.append(flow_path)\n",
- "\n",
- "        self.size = len(self.image_list)\n",
- "\n",
- "        assert len(self.image_list) == len(self.flow_list)\n",
- "\n",
- "    def __getitem__(self, index):\n",
- "        index = index % self.size\n",
- "\n",
- "        img = cv2.imread(self.image_list[index])\n",
- "        if img is None:\n",
- "            # cv2.imread reports an unreadable file by returning None, not by raising.\n",
- "            raise IOError('could not read image: ' + self.image_list[index])\n",
- "\n",
- "        # The context manager closes the HDF5 file; it used to stay open after every item.\n",
- "        with h5py.File(self.flow_list[index], 'r') as mat:\n",
- "            flow = np.array(mat['img'])  # KeyError if the file has no 'img' dataset\n",
- "        # np.transpose reverses every axis.\n",
- "        # NOTE(review): presumably undoes MATLAB's column-major layout to give (H, W, 2); confirm.\n",
- "        # float32 so that, after ToTensor, the flow has the same dtype as the\n",
- "        # uint8-derived image tensors (MATLAB doubles would load as float64).\n",
- "        flow = np.transpose(flow).astype(np.float32)\n",
- "\n",
- "        ce = cv2.resize(img, (224, 224))\n",
- "        cd = cv2.resize(img, (28, 28))\n",
- "        # NOTE(review): the flow vectors are not rescaled to the new resolution.\n",
- "        flow = cv2.resize(flow, (224, 224))\n",
- "\n",
- "        if self.transforms is not None:\n",
- "            ce = self.transforms(ce)\n",
- "            cd = self.transforms(cd)\n",
- "            flow = self.transforms(flow)\n",
- "\n",
- "        return flow, ce, cd\n",
- "\n",
- "    def __len__(self):\n",
- "        return self.size\n",
- "\n",
- "\n",
- "def KL_divergence(recon_x, x, mu, logvar):\n",
- "    \"\"\"Return BCE(recon_x, x) + KL(N(mu, exp(logvar)) || N(0, I)), both summed.\n",
- "\n",
- "    Despite the name, the result also contains the binary cross-entropy\n",
- "    reconstruction term.\n",
- "\n",
- "    Args:\n",
- "        recon_x: reconstruction with values in [0, 1].\n",
- "        x: target with the same number of elements as recon_x.\n",
- "        mu, logvar: latent mean and log-variance.\n",
- "\n",
- "    Returns:\n",
- "        Scalar tensor.\n",
- "    \"\"\"\n",
- "    # Reshape the target to recon_x's shape. The old\n",
- "    # view(-1, x.shape[0]*x.shape[1]*x.shape[2]) only lined up for particular\n",
- "    # shapes (it multiplied a leading batch axis into the row length).\n",
- "    target = x.reshape(recon_x.shape)\n",
- "    # reduction='sum' replaces the deprecated size_average=False.\n",
- "    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')\n",
- "    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
- "    return BCE + KLD\n",
- "\n",
- "def loss_function(recon_of_norm, of_norm, recon_m_x, m_x, recon_m_y, m_y, z, mu, logvar, eps=1e-12):\n",
- "    \"\"\"Return the CVAE loss as a differentiable scalar tensor.\n",
- "\n",
- "    The loss is the sum of\n",
- "      * the end-point error between the normalized flows: per-position\n",
- "        Euclidean distance over components 0 and 1 of the last axis, summed,\n",
- "      * |recon_m_x - m_x| + |recon_m_y - m_y|, the error of the mean flow,\n",
- "      * the KL divergence of N(mu, exp(logvar)) from N(0, I).\n",
- "\n",
- "    Args:\n",
- "        recon_of_norm, of_norm: predicted / target normalized flow tensors of\n",
- "            equal shape, with the flow components on the last axis.\n",
- "        recon_m_x, m_x, recon_m_y, m_y: predicted / target mean flow\n",
- "            components (tensors or floats).\n",
- "        z: latent sample; unused, kept so existing calls still work.\n",
- "        mu, logvar: latent mean and log-variance.\n",
- "        eps: added under the square root so the gradient stays finite where\n",
- "            the two flows coincide (sqrt has an infinite slope at 0).\n",
- "    \"\"\"\n",
- "    # Vectorized replacement for the old per-pixel loop, which used math.sqrt and\n",
- "    # np.sum: those fail on CUDA tensors and never build an autograd graph.\n",
- "    diff = of_norm - recon_of_norm\n",
- "    of_loss = torch.sqrt(diff[..., 0] ** 2 + diff[..., 1] ** 2 + eps).sum()\n",
- "\n",
- "    # sqrt((a - b) ** 2) == |a - b|; the old math.sqrt(a - b) raised ValueError when a < b.\n",
- "    # NOTE(review): assumes an absolute error on the means was intended.\n",
- "    mean_loss = abs(recon_m_x - m_x) + abs(recon_m_y - m_y)\n",
- "\n",
- "    # KL term only: the old KL_divergence(recon_z, z, ...) call used an undefined\n",
- "    # name recon_z, and train() gets no reconstructed z back from the model.\n",
- "    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
- "\n",
- "    return of_loss + mean_loss + kld\n",
- "\n",
- "\n",
- "\n",
- "def train(model, optimizer, epoch, train_loader):\n",
- "    \"\"\"Run one training epoch over train_loader, printing progress.\n",
- "\n",
- "    Args:\n",
- "        model: called as model(flow, ce, cd), returning (flow_out, z, mu, logvar).\n",
- "        optimizer: optimizer over the model's parameters.\n",
- "        epoch: epoch number, used only in the log lines.\n",
- "        train_loader: DataLoader yielding (flow, ce, cd) batches.\n",
- "\n",
- "    Reads the global ``args`` (``args.log_interval``), which must be set first.\n",
- "    \"\"\"\n",
- "    model.train()\n",
- "    train_loss = 0.0\n",
- "    for batch_idx, (flow, ce, cd) in enumerate(train_loader):\n",
- "        # Variable() has been a no-op wrapper since PyTorch 0.4; plain tensors suffice.\n",
- "        flow, ce, cd = flow.cuda(), ce.cuda(), cd.cuda()\n",
- "        optimizer.zero_grad()\n",
- "\n",
- "        flow_out, z, mu, logvar = model(flow, ce, cd)\n",
- "\n",
- "        # With the notebook's ToTensor transform, batches are (B, C, H, W), but\n",
- "        # loss_function indexes the flow components on the last axis.\n",
- "        # NOTE(review): assumes flow_out has the same layout as flow; confirm in model.CVAE.\n",
- "        flow_hwc = flow.permute(0, 2, 3, 1)\n",
- "        out_hwc = flow_out.permute(0, 2, 3, 1)\n",
- "\n",
- "        # Normalize with torch: np.linalg.norm cannot read CUDA tensors and would\n",
- "        # cut the autograd graph. The clamp avoids 0/0 for an all-zero flow.\n",
- "        flow_norm = flow_hwc / flow_hwc.norm().clamp(min=1e-12)\n",
- "        flow_out_norm = out_hwc / out_hwc.norm().clamp(min=1e-12)\n",
- "\n",
- "        loss = loss_function(flow_out_norm, flow_norm,\n",
- "                             out_hwc[..., 0].mean(), flow_hwc[..., 0].mean(),\n",
- "                             out_hwc[..., 1].mean(), flow_hwc[..., 1].mean(),\n",
- "                             z, mu, logvar)\n",
- "        loss.backward()\n",
- "        optimizer.step()\n",
- "\n",
- "        # loss.item() replaces the broken loss.flow[0] attribute access.\n",
- "        batch_loss = loss.item()\n",
- "        train_loss += batch_loss\n",
- "        if batch_idx % args.log_interval == 0:\n",
- "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
- "                epoch, batch_idx * len(flow), len(train_loader.dataset),\n",
- "                100. * batch_idx / len(train_loader),\n",
- "                batch_loss / len(flow)))\n",
- "\n",
- "    print('====> Epoch: {} Average loss: {:.4f}'.format(\n",
- "        epoch, train_loss / len(train_loader.dataset)))\n",
- "\n",
- "\n",
- "def test(model, test_loader):\n",
- "    \"\"\"Evaluate the model on test_loader and print the average per-sample loss.\n",
- "\n",
- "    Args:\n",
- "        model: called as model(flow, ce, cd), returning (flow_out, z, mu, logvar).\n",
- "        test_loader: DataLoader yielding (flow, ce, cd) batches.\n",
- "    \"\"\"\n",
- "    model.eval()\n",
- "    test_loss = 0.0\n",
- "    # no_grad() replaces Variable(..., volatile=True): no autograd graph is kept.\n",
- "    with torch.no_grad():\n",
- "        for flow, ce, cd in test_loader:\n",
- "            flow, ce, cd = flow.cuda(), ce.cuda(), cd.cuda()\n",
- "            # Same inputs as in train(); the old model(data) passed the torch.utils.data module.\n",
- "            flow_out, z, mu, logvar = model(flow, ce, cd)\n",
- "\n",
- "            # Channels last, since loss_function indexes flow components on the last axis.\n",
- "            # NOTE(review): assumes flow_out shares flow's (B, C, H, W) layout.\n",
- "            flow_hwc = flow.permute(0, 2, 3, 1)\n",
- "            out_hwc = flow_out.permute(0, 2, 3, 1)\n",
- "            flow_norm = flow_hwc / flow_hwc.norm().clamp(min=1e-12)\n",
- "            flow_out_norm = out_hwc / out_hwc.norm().clamp(min=1e-12)\n",
- "\n",
- "            loss = loss_function(flow_out_norm, flow_norm,\n",
- "                                 out_hwc[..., 0].mean(), flow_hwc[..., 0].mean(),\n",
- "                                 out_hwc[..., 1].mean(), flow_hwc[..., 1].mean(),\n",
- "                                 z, mu, logvar)\n",
- "            # .item() accumulates a Python float rather than tensors.\n",
- "            test_loss += loss.item()\n",
- "\n",
- "    # The MNIST-era 28x28 comparison tensor was never used and cannot be built\n",
- "    # from flow-sized output, so it has been dropped.\n",
- "    test_loss /= len(test_loader.dataset)\n",
- "    print('====> Test set loss: {:.4f}'.format(test_loss))\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "if __name__ == '__main__':\n",
- "\n",
- "    # Set up training settings from command line options, or use default\n",
- "    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n",
- "    parser.add_argument('--batch-size', type=int, default=7, metavar='N',\n",
- "                        help='input batch size for training (default: 7)')\n",
- "    parser.add_argument('--test-batch-size', type=int, default=7, metavar='N',\n",
- "                        help='input batch size for testing (default: 7)')\n",
- "    parser.add_argument('--epochs', type=int, default=10, metavar='N',\n",
- "                        help='number of epochs to train (default: 10)')\n",
- "    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',\n",
- "                        help='learning rate (default: 0.001)')\n",
- "    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n",
- "                        help='SGD momentum (default: 0.9)')\n",
- "    parser.add_argument('--log-interval', type=int, default=10, metavar='N',\n",
- "                        help='how many batches to wait before logging training status')\n",
- "\n",
- "    # Parse an empty argument list so the defaults apply inside a notebook\n",
- "    # kernel (whose own command-line flags would otherwise be parsed) without\n",
- "    # overwriting sys.argv the way the old sys.argv=[''] workaround did.\n",
- "    args = parser.parse_args(args=[])\n",
- "\n",
- "    # Instantiate the model\n",
- "    model = CVAE().cuda()\n",
- "\n",
- "    # Adam optimizer using the --lr setting (previously ignored in favour of a\n",
- "    # hard-coded 1e-3). NOTE: --momentum is parsed but Adam does not use it.\n",
- "    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n",
- "\n",
- "    # Load data\n",
- "    transformations = transforms.Compose([transforms.ToTensor()])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "if __name__ == '__main__':\n",
- "    # This is its own code cell, so it needs its own guard; an indented\n",
- "    # continuation of the previous cell's block is an IndentationError here.\n",
- "    # NOTE(review): train and test both read ./training, so the test loss is\n",
- "    # not measured on held-out data.\n",
- "    Mpi_Sintel_train = MpiSintel('./training', 'final', transformations)\n",
- "    Mpi_Sintel_test = MpiSintel('./training', 'final', transformations)\n",
- "\n",
- "    train_loader = torch.utils.data.DataLoader(dataset=Mpi_Sintel_train, batch_size=args.batch_size, shuffle=True)\n",
- "    test_loader = torch.utils.data.DataLoader(dataset=Mpi_Sintel_test, batch_size=args.test_batch_size, shuffle=True)\n",
- "\n",
- "    # Train & test the model\n",
- "    for epoch in range(1, args.epochs + 1):\n",
- "        train(model, optimizer, epoch, train_loader)\n",
- "        test(model, test_loader)\n",
- "\n",
- "    # Save the model for future use. __file__ is undefined in a notebook\n",
- "    # kernel, so fall back to the working directory there.\n",
- "    if '__file__' in globals():\n",
- "        package_dir = os.path.dirname(os.path.abspath(__file__))\n",
- "    else:\n",
- "        package_dir = os.getcwd()\n",
- "    model_path = os.path.join(package_dir, 'model')\n",
- "    torch.save(model.state_dict(), model_path)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.4"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
Add Comment
Please, Sign In to add comment