Advertisement
Guest User

Untitled

a guest
Feb 22nd, 2019
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.42 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": null,
  6. "metadata": {
  7. "ExecuteTime": {
  8. "end_time": "2019-02-20T12:55:25.150873Z",
  9. "start_time": "2019-02-20T12:55:24.620811Z"
  10. }
  11. },
  12. "outputs": [],
  13. "source": [
  14. "import numpy as np\n",
  15. "from random import sample\n",
  16. "import matplotlib.pyplot as plt\n",
  17. "%matplotlib inline"
  18. ]
  19. },
  20. {
  21. "cell_type": "markdown",
  22. "metadata": {},
  23. "source": [
  24. "## Exercise 1.4: perceptron learning algorithm (PLA) on a synthetic, linearly separable 2-D dataset"
  25. ]
  26. },
  27. {
  28. "cell_type": "code",
  29. "execution_count": null,
  30. "metadata": {
  31. "ExecuteTime": {
  32. "end_time": "2019-02-20T12:55:26.077169Z",
  33. "start_time": "2019-02-20T12:55:26.067921Z"
  34. }
  35. },
  36. "outputs": [],
  37. "source": [
  38. "np.random.seed(0) # fix the global NumPy RNG: data, initial weights and PLA picks all use it\n",
  39. "def fx(x): # target line f(x) = -0.7 x + 1\n",
  40. " return(-0.7 * x + 1)\n",
  41. "# 20 random points (x1, x2) in [0, 1)^2, prefixed with a bias column of ones\n",
  42. "xi = np.column_stack((np.ones(20), np.random.rand(20,2)))\n",
  43. "yi = (fx(xi[:,1]) > xi[:,2]) # True where the point lies below the target line"
  44. ]
  45. },
  46. {
  47. "cell_type": "code",
  48. "execution_count": null,
  49. "metadata": {
  50. "ExecuteTime": {
  51. "end_time": "2019-02-20T12:55:26.761259Z",
  52. "start_time": "2019-02-20T12:55:26.755219Z"
  53. }
  54. },
  55. "outputs": [],
  56. "source": [
  57. "# labels in {-1, +1}: +1 where yi is True (point below the target line), else -1\n",
  58. "# float values keep the same dtype as the original np.zeros-based construction\n",
  59. "yp = np.where(yi, 1.0, -1.0)"
  60. ]
  61. },
  62. {
  63. "cell_type": "code",
  64. "execution_count": null,
  65. "metadata": {
  66. "ExecuteTime": {
  67. "end_time": "2019-02-20T12:55:27.619844Z",
  68. "start_time": "2019-02-20T12:55:27.284779Z"
  69. }
  70. },
  71. "outputs": [],
  72. "source": [
  73. "xs = np.linspace(0, 1, 11) # full [0, 1] span; arange(0,1,.1) would stop at 0.9\n",
  74. "plt.scatter(xi[:,1], xi[:,2], c=yp) # data coloured by label\n",
  75. "plt.plot(xs, fx(xs)) # target line f(x)\n",
  76. "plt.savefig(\"aml_e1.4_fx.png\")"
  77. ]
  78. },
  79. {
  80. "cell_type": "code",
  81. "execution_count": null,
  82. "metadata": {
  83. "ExecuteTime": {
  84. "end_time": "2019-02-20T12:55:38.064284Z",
  85. "start_time": "2019-02-20T12:55:38.056503Z"
  86. },
  87. "code_folding": [
  88. 0
  89. ]
  90. },
  91. "outputs": [],
  92. "source": [
  93. "def fwd(x, y, w):\n",
  94. " ''' Perceptron miss vector.\n",
  95. " in: x inputs (bias column first), y labels in {-1, +1}, w current weights\n",
  96. " out: boolean array, True where sign(x @ w) != y (a zero score counts as a miss) '''\n",
  97. " predicted = np.sign(x @ w) # +1 / -1, and 0 exactly on the boundary\n",
  98. " # with w of shape (d, 1) and y of shape (n,) as used in this notebook,\n",
  99. " # the transposed (n, 1) predictions become a (1, n) row broadcast against y;\n",
  100. " # the training loop relies on that row shape when it collects miss indices\n",
  101. " return(predicted.T != y)"
  102. ]
  103. },
  104. {
  105. "cell_type": "code",
  106. "execution_count": null,
  107. "metadata": {
  108. "ExecuteTime": {
  109. "end_time": "2019-02-20T12:55:46.651852Z",
  110. "start_time": "2019-02-20T12:55:46.631850Z"
  111. }
  112. },
  113. "outputs": [],
  114. "source": [
  115. "w = np.random.rand(xi.shape[1],1) * 2 - 1 # random initial weights in [-1, 1)\n",
  116. "neq = fwd(xi, yp, w) # (1, n) row, True where a point is misclassified\n",
  117. "MAX_UPDATES = 10000 # PLA only halts on linearly separable data; cap as a safeguard\n",
  118. "n_updates = 0\n",
  119. "while neq.any() and n_updates < MAX_UPDATES:\n",
  120. " ud = np.random.choice(np.flatnonzero(neq), 1) # pick a random misclassified point\n",
  121. " w = w + yp[ud] * xi[ud].T # PLA update: w <- w + y_i * x_i\n",
  122. " neq = fwd(xi, yp, w); n_updates += 1\n",
  123. "print(n_updates, 'updates,', neq.sum(), 'misses left; final weights:', w.ravel())"
  124. ]
  125. },
  126. {
  127. "cell_type": "code",
  128. "execution_count": null,
  129. "metadata": {
  130. "ExecuteTime": {
  131. "end_time": "2019-02-20T12:59:46.858343Z",
  132. "start_time": "2019-02-20T12:59:46.850840Z"
  133. }
  134. },
  135. "outputs": [],
  136. "source": [
  137. "print(-w[1] / w[2]) # slope of the learned boundary w0 + w1*x1 + w2*x2 = 0 (target: -0.7)\n",
  138. "print(-w[0] / w[2]) # intercept, i.e. x2 at x1 = 0 (target: 1)"
  139. ]
  140. },
  141. {
  142. "cell_type": "code",
  143. "execution_count": null,
  144. "metadata": {
  145. "ExecuteTime": {
  146. "end_time": "2019-02-20T13:00:51.666033Z",
  147. "start_time": "2019-02-20T13:00:51.661232Z"
  148. }
  149. },
  150. "outputs": [],
  151. "source": [
  152. "## learned boundary: w0 + w1*x1 + w2*x2 = 0  =>  x2 = -(w0 + w1*x1) / w2 ##\n",
  153. "wfxn = -(w[0] + w[1] * np.arange(0,1,.1)) / w[2]"
  154. ]
  155. },
  156. {
  157. "cell_type": "code",
  158. "execution_count": null,
  159. "metadata": {
  160. "ExecuteTime": {
  161. "end_time": "2019-02-20T13:00:52.706379Z",
  162. "start_time": "2019-02-20T13:00:52.433112Z"
  163. }
  164. },
  165. "outputs": [],
  166. "source": [
  167. "grid = np.arange(0,1,.1) # must match the grid used to compute wfxn\n",
  168. "plt.scatter(x=xi[:,1], y=xi[:,2], c=yp)\n",
  169. "plt.plot(grid, wfxn, label='g (learned)'); plt.plot(grid, fx(grid), '--', label='f (target)'); plt.legend()\n",
  170. "plt.savefig(\"aml_e1.4_gx.png\")"
  171. ]
  172. }
  173. ],
  174. "metadata": {
  175. "kernelspec": {
  176. "display_name": "Python 3",
  177. "language": "python",
  178. "name": "python3"
  179. },
  180. "language_info": {
  181. "codemirror_mode": {
  182. "name": "ipython",
  183. "version": 3
  184. },
  185. "file_extension": ".py",
  186. "mimetype": "text/x-python",
  187. "name": "python",
  188. "nbconvert_exporter": "python",
  189. "pygments_lexer": "ipython3",
  190. "version": "3.6.7"
  191. }
  192. },
  193. "nbformat": 4,
  194. "nbformat_minor": 2
  195. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement