{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/tao/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n"
]
}
],
  17. "source": [
  18. "import os\n",
  19. "import math\n",
  20. "import torch\n",
  21. "import torch.nn.functional as F\n",
  22. "import torch.optim as optim\n",
  23. "import torch.utils.data as data\n",
  24. "import numpy as np\n",
  25. "import h5py\n",
  26. "import cv2\n",
  27. "from torch.autograd import Variable\n",
  28. "from torchvision import datasets, transforms\n",
  29. "from glob import glob\n",
  30. "import argparse\n",
  31. "from model import CVAE\n",
  32. "\n",
  33. "\n",
  34. "class MpiSintel(data.Dataset):\n",
  35. " \n",
  36. " def __init__(self, root = '', dstype = 'final', transforms = None):\n",
  37. "\n",
  38. " flow_root = os.path.join(root, 'flow')\n",
  39. " image_root = os.path.join(root, dstype)\n",
  40. "\n",
  41. " file_list = sorted(glob(os.path.join(flow_root, '*/*.mat')))\n",
  42. "\n",
  43. " self.flow_list = []\n",
  44. " self.image_list = []\n",
  45. " self.transforms = transforms\n",
  46. "\n",
  47. " for file in file_list:\n",
  48. " \n",
  49. " fbase = file[len(flow_root)+1:]\n",
  50. " fprefix = fbase[:-8]\n",
  51. " fnum = int(fbase[-8:-4])\n",
  52. "\n",
  53. " img = os.path.join(image_root, fprefix + \"%04d\"%fnum + '.png')\n",
  54. "\n",
  55. " if not os.path.isfile(img) or not os.path.isfile(file):\n",
  56. " continue\n",
  57. "\n",
  58. " self.image_list += [img]\n",
  59. " self.flow_list += [file]\n",
  60. "\n",
  61. " self.size = len(self.image_list)\n",
  62. "\n",
  63. " assert (len(self.image_list) == len(self.flow_list))\n",
  64. "\n",
  65. " def __getitem__(self, index):\n",
  66. "\n",
  67. " index = index % self.size\n",
  68. "\n",
  69. " img = cv2.imread(self.image_list[index])\n",
  70. "\n",
  71. " mat = h5py.File(self.flow_list[index],'r')\n",
  72. " flow = mat.get('img') \n",
  73. " flow = np.array(flow)\n",
  74. " flow = np.transpose(flow)\n",
  75. " \n",
  76. " ce = cv2.resize(img, (224,224))\n",
  77. " cd = cv2.resize(img, (28,28))\n",
  78. " flow = cv2.resize(flow, (224,224))\n",
  79. " \n",
  80. " if self.transforms is not None:\n",
  81. " ce = self.transforms(ce)\n",
  82. " cd = self.transforms(cd)\n",
  83. " flow = self.transforms(flow)\n",
  84. "\n",
  85. " return flow, ce, cd\n",
  86. "\n",
  87. " def __len__(self):\n",
  88. " return self.size\n",
  89. "\n",
  90. "\n",
  91. "def KL_divergence(recon_x, x, mu, logvar):\n",
  92. " BCE = F.binary_cross_entropy(recon_x, x.view(-1, x.shape[0]*x.shape[1]*x.shape[2]), size_average=False)\n",
  93. " KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
  94. " return BCE + KLD\n",
  95. "\n",
  96. "def loss_function(recon_of_norm, of_norm, recon_m_x, m_x, recon_m_y, m_y, z, mu, logvar):\n",
  97. " of_loss = np.sum([math.sqrt( (of_norm[i,j,0]-recon_of_norm[i,j,0])**2 + (of_norm[i,j,1]-recon_of_norm[i,j,1])**2 ) \n",
  98. " for i in range(of_norm.shape[0]) for j in range(of_norm.shape[1])])\n",
  99. " \n",
  100. " return of_loss + math.sqrt(recon_m_x-m_x) + math.sqrt(recon_m_y-m_y) + KL_divergence(recon_z, z, mu, logvar)\n",
  101. "\n",
  102. "\n",
  103. "\n",
  104. "def train(model, optimizer, epoch, train_loader):\n",
  105. " model.train()\n",
  106. " train_loss = 0\n",
  107. " for batch_idx, (flow, ce, cd) in enumerate(train_loader):\n",
  108. "\n",
  109. " flow, ce, cd = Variable(flow).cuda(), Variable(ce).cuda(), Variable(cd).cuda()\n",
  110. " optimizer.zero_grad()\n",
  111. " \n",
  112. " #recon_batch, mu, logvar = model(flow, ce, cd)\n",
  113. " flow_out, z, mu, logvar = model(flow, ce, cd)\n",
  114. "\n",
  115. " #normalize\n",
  116. " flow_out_norm = flow_out / np.linalg.norm(flow_out)\n",
  117. " flow_norm = flow / np.linalg.norm(flow)\n",
  118. " \n",
  119. " loss = loss_function(flow_out_norm, flow_norm, np.mean(flow_out[:,:,0]), np.mean(flow[:,:,0]),np.mean(flow_out[:,:,1]), np.mean(flow[:,:,1]),z, mu, logvar)\n",
  120. " loss.backward()\n",
  121. " train_loss += loss.flow[0]\n",
  122. " optimizer.step()\n",
  123. " if batch_idx % args.log_interval == 0:\n",
  124. " print('Train Epoch: {} [{}/{} (args{:.0f}%)]\\tLoss: {:.6f}'.format(\n",
  125. " epoch, batch_idx * len(flow), len(train_loader.dataset),\n",
  126. " 100. * batch_idx / len(train_loader),\n",
  127. " loss.flow[0] / len(flow)))\n",
  128. "\n",
  129. " print('====> Epoch: {} Average loss: {:.4f}'.format(\n",
  130. " epoch, train_loss / len(train_loader.dataset)))\n",
  131. "\n",
  132. "\n",
  133. "def test(model, test_loader):\n",
  134. " model.eval()\n",
  135. " test_loss = 0\n",
  136. " for i, (flow, ce, cd) in enumerate(test_loader):\n",
  137. " flow, ce, cd = Variable(flow, volatile=True).cuda(), Variable(ce).cuda(), Variable(cd).cuda()\n",
  138. " flow_out,z, mu, logvar = model(data)\n",
  139. " \n",
  140. " #normalize\n",
  141. " flow_out_norm = flow_out / np.linalg.norm(flow_out)\n",
  142. " flow_norm = flow / np.linalg.norm(flow)\n",
  143. " \n",
  144. " test_loss += loss_function(flow_out_norm, flow_norm, np.mean(flow_out[:,:,0]), np.mean(flow[:,:,0]),np.mean(flow_out[:,:,1]), np.mean(flow[:,:,1]),z, mu, logvar)\n",
  145. " if i == 0:\n",
  146. " n = min(flow.size(0), 8)\n",
  147. " comparison = torch.cat([flow[:n], flow_out.view(args.batch_size, 1, 28, 28)[:n]])\n",
  148. "\n",
  149. " test_loss /= len(test_loader.dataset)\n",
  150. " print('====> Test set loss: {:.4f}'.format(test_loss))\n"
]
},
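{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, `KL_divergence` above implements the usual VAE objective: a binary cross-entropy reconstruction term plus the closed-form KL divergence between the approximate posterior $\\mathcal{N}(\\mu, \\sigma^2)$ and the prior $\\mathcal{N}(0, I)$,\n",
"\n",
"$$\\mathrm{KLD} = -\\tfrac{1}{2} \\sum_j \\left(1 + \\log \\sigma_j^2 - \\mu_j^2 - \\sigma_j^2\\right),$$\n",
"\n",
"where `logvar` holds $\\log \\sigma^2$ and `mu` holds $\\mu$.\n"
]
},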
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
  159. "if __name__ == '__main__':\n",
  160. "\n",
  161. " # Set up training settings from command line options, or use default\n",
  162. " parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n",
  163. " parser.add_argument('--batch-size', type=int, default=7, metavar='N',\n",
  164. " help='input batch size for training (default: 7)')\n",
  165. " parser.add_argument('--test-batch-size', type=int, default=7, metavar='N',\n",
  166. " help='input batch size for testing (default: 7)')\n",
  167. " parser.add_argument('--epochs', type=int, default=10, metavar='N',\n",
  168. " help='number of epochs to train (default: 10)')\n",
  169. " parser.add_argument('--lr', type=float, default=0.001, metavar='LR',\n",
  170. " help='learning rate (default: 0.001)')\n",
  171. " parser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n",
  172. " help='SGD momentum (default: 0.9)')\n",
  173. " parser.add_argument('--log-interval', type=int, default=10, metavar='N',\n",
  174. " help='how many batches to wait before logging training status')\n",
  175. " \n",
  176. " import sys; sys.argv=['']; del sys\n",
  177. " \n",
  178. " args = parser.parse_args()\n",
  179. "\n",
  180. " # Instantiate the model\n",
  181. " model = CVAE().cuda()\n",
  182. "\n",
  183. " # Choose SGD as the optimizer, initialize it with the parameters & settings\n",
  184. " optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
  185. "\n",
  186. " # Load data\n",
  187. " transformations = transforms.Compose([transforms.ToTensor()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
  196. " Mpi_Sintel_train = MpiSintel('./training', 'final', transformations)\n",
  197. " Mpi_Sintel_test = MpiSintel('./training', 'final', transformations)\n",
  198. "\n",
  199. " train_loader = torch.utils.data.DataLoader(dataset=Mpi_Sintel_train, batch_size=args.batch_size, shuffle=True)\n",
  200. " test_loader = torch.utils.data.DataLoader(dataset=Mpi_Sintel_test, batch_size=args.test_batch_size, shuffle=True)\n",
  201. "\n",
  202. " # Train & test the model\n",
  203. " for epoch in range(1, args.epochs + 1):\n",
  204. " train(model, optimizer, epoch, train_loader)\n",
  205. " test(model, test_loader)\n",
  206. "\n",
  207. "\n",
  208. " # Save the model for future use\n",
  209. " package_dir = os.path.dirname(os.path.abspath(__file__))\n",
  210. " model_path = os.path.join(package_dir,'model')\n",
  211. " torch.save(model.state_dict(), model_path)"
]
},
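{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of reloading the saved weights for inference; assumes\n",
"# CVAE() takes no constructor arguments (as above) and that model_path is\n",
"# the state-dict file written by the previous cell.\n",
"model = CVAE().cuda()\n",
"model.load_state_dict(torch.load(model_path))\n",
"model.eval()"
]
},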
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}