Guest User

Untitled

a guest
Nov 20th, 2018
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.55 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 0,
  6. "metadata": {
  7. "collapsed": false
  8. },
  9. "outputs": [],
  10. "source": [
  11. "import torch\n",
  12. "import torch.nn as nn\n",
  13. "import torch.optim as optim\n",
  14. "import torch.nn.functional as F"
  15. ]
  16. },{
  17. "cell_type": "code",
  18. "execution_count": 33,
  19. "metadata": {
  20. "collapsed": false
  21. },
  22. "outputs": [],
  23. "source": [
  24. "import torch\n",
  25. "import torch.nn as nn\n",
  26. "import torch.nn.functional as F\n",
  27. "import torch.optim as optim\n",
  28. "import numpy as np\n",
  29. "\n",
  30. "class Testen(nn.Module):\n",
  31. "    def __init__(self):\n",
  32. "        super(Testen, self).__init__()\n",
  33. "        self.layer_1 = torch.nn.Linear(3072, 30)  # 3*32*32 flattened pixels in\n",
  34. "        self.layer_2 = torch.nn.Linear(30, 30)\n",
  35. "        self.layer_3 = torch.nn.Linear(30, 20)  # one output per class (placeholder labels use 20 classes)\n",
  36. "\n",
  37. "    def forward(self, x):\n",
  38. "        x = x.view(-1, 3072)\n",
  39. "        x = F.relu(self.layer_1(x))\n",
  40. "        x = F.relu(self.layer_2(x))\n",
  41. "        x = self.layer_3(x)  # raw logits: CrossEntropyLoss applies log-softmax itself\n",
  42. "        return x\n"
  43. ]
  44. },{
  45. "cell_type": "code",
  46. "execution_count": null,
  47. "metadata": {},
  48. "outputs": [],
  49. "source": [
  50. "net = Testen()"
  51. ]
  52. },
  53. {
  54. "cell_type": "code",
  55. "execution_count": null,
  56. "metadata": {
  57. "collapsed": true
  58. },
  59. "outputs": [],
  60. "source": [
  61. "criterion = nn.CrossEntropyLoss() \n",
  62. "optimizer = optim.Adam(net.parameters(), lr = 0.0001)"
  63. ]
  64. },
  65. {
  66. "cell_type": "code",
  67. "execution_count": null,
  68. "metadata": {},
  69. "outputs": [],
  70. "source": [
  71. "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
  72. "net.to(device)"
  73. ]
  74. },
  75. {
  76. "cell_type": "code",
  77. "execution_count": null,
  78. "metadata": {},
  79. "outputs": [],
  80. "source": [
  81. "##############################\n",
  82. "# REPLACE WITH YOUR OWN DATA #\n",
  83. "##############################\n",
  84. "x = torch.from_numpy(np.random.rand(1042,32,3,32,32)).type(dtype=torch.float).to(device) #(batches, batch_size, channels, height, width) \n",
  85. "y = torch.from_numpy(np.random.randint(20,size=(1042, 32))).long().to(device) #(batches, batch_size), class indices in [0, 20)"
  86. ]
  87. },
  88. {
  89. "cell_type": "code",
  90. "execution_count": null,
  91. "metadata": {},
  92. "outputs": [],
  93. "source": [
  94. "for epoch in range(2):  # loop over the dataset multiple times\n",
  95. "    running_loss = 0.0\n",
  96. "    for i in range(0, x.shape[0]):\n",
  97. "        # get one mini-batch and its labels\n",
  98. "        inputs, labels = x[i], y[i]\n",
  99. "\n",
  100. "        # zero the parameter gradients\n",
  101. "        optimizer.zero_grad()\n",
  102. "\n",
  103. "        # forward + backward + optimize\n",
  104. "        outputs = net(inputs)\n",
  105. "        loss = criterion(outputs, labels)\n",
  106. "        loss.backward()\n",
  107. "        optimizer.step()\n",
  108. "\n",
  109. "        # print statistics\n",
  110. "        running_loss += loss.item()\n",
  111. "        if i % 200 == 199:\n",
  112. "            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 200))\n",
  113. "            running_loss = 0.0\n",
  114. "print('Finished Training')"
  115. ]
  116. },
  117. {
  118. "cell_type": "code",
  119. "execution_count": null,
  120. "metadata": {
  121. "collapsed": true
  122. },
  123. "outputs": [],
  124. "source": []
  125. }
  126. ],
  127. "metadata": {
  128. "kernelspec": {
  129. "display_name": "Python 3",
  130. "language": "python",
  131. "name": "python3"
  132. },
  133. "language_info": {
  134. "codemirror_mode": {
  135. "name": "ipython",
  136. "version": 3
  137. },
  138. "file_extension": ".py",
  139. "mimetype": "text/x-python",
  140. "name": "python",
  141. "nbconvert_exporter": "python",
  142. "pygments_lexer": "ipython3",
  143. "version": "3.5.4"
  144. }
  145. },
  146. "nbformat": 4,
  147. "nbformat_minor": 2
  148. }
Add Comment
Please, Sign In to add comment