Guest User

Untitled

a guest
Dec 16th, 2018
126
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.75 KB | None | 0 0
  1. import numpy as np
  2. from scipy.optimize import minimize
  3. from matplotlib import pyplot as plt
  4.  
  5. ONE_DATA = 30
  6. x = np.vstack((np.random.normal(loc=-2, scale=1.0, size=(ONE_DATA, 2)),\
  7. np.random.normal(loc=2, scale=1.0, size=(ONE_DATA, 2))))
  8. y = np.hstack((np.repeat(-1, ONE_DATA), np.repeat(1, ONE_DATA)))
  9.  
  10. def get_obj(x, y):
  11. def obj(alpha):
  12. second = np.array([alpha[i]*alpha[j]*y[i]*y[j]*np.dot(x[i], x[j]) \
  13. for i in range(len(alpha)) for j in range(len(alpha))])
  14. return - np.sum(alpha) + .5 * np.sum(second)
  15. return obj
  16.  
  17. def get_g(i):
  18. def g(alpha):
  19. return alpha[i]
  20. return g
  21.  
  22. def get_h(y):
  23. def h(alpha):
  24. return np.sum([alpha[i]*y[i] for i in range(len(alpha))])
  25. return h
  26.  
  27. cons = [
  28. {'type': 'eq', 'fun': get_h(y)}
  29. ]
  30. for i in range(len(x)):
  31. cons.append({'type': 'ineq', 'fun': get_g(i)})
  32.  
  33. res = minimize(get_obj(x, y), np.zeros(len(x)), constraints=cons, method="SLSQP")
  34. alpha_hat = res.x
  35. tol = 1e-10
  36. support_index = np.where(alpha_hat > tol)[0]
  37. non_suppoert_index = np.where(alpha_hat <= tol)[0]
  38. alpha_hat[non_suppoert_index] = 0
  39. w = np.sum([alpha_hat[i]*y[i]*x[i] for i in range(len(x))], axis=0)
  40. x_m = x[support_index][np.where(y[support_index]==-1)]
  41. x_p = x[support_index][np.where(y[support_index]==1)]
  42. b = - ( np.dot(w, x_p[0]) + np.dot(w, x_m[0]) ) / 2
  43.  
  44. def classifier(w1, w2, b, x):
  45. return -w1/w2*x - b/w2
  46.  
  47. for xi, yi in zip(x, y):
  48. if (xi[0] in x_m and xi[1] in x_m) or \
  49. (xi[0] in x_p and xi[1] in x_p):
  50. plt.plot(xi[0], xi[1], 'gv')
  51. continue
  52. if yi == 1:
  53. plt.plot(xi[0], xi[1], 'ro')
  54. else:
  55. plt.plot(xi[0], xi[1], 'bx')
  56. x = np.linspace(-5, 5, 100)
  57. y = classifier(w[0], w[1], b, x)
  58. plt.plot(x, y)
  59. plt.xlim([-5, 5])
  60. plt.ylim([-5, 5])
  61. plt.show()
Add Comment
Please, Sign In to add comment