daily pastebin goal
9%
SHARE
TWEET

Untitled

a guest Dec 16th, 2018 67 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy as np
  2. from scipy.optimize import minimize
  3. from matplotlib import pyplot as plt
  4.  
  5. ONE_DATA = 30
  6. x = np.vstack((np.random.normal(loc=-2, scale=1.0, size=(ONE_DATA, 2)),\
  7.         np.random.normal(loc=2, scale=1.0, size=(ONE_DATA, 2))))
  8. y = np.hstack((np.repeat(-1, ONE_DATA), np.repeat(1, ONE_DATA)))
  9.  
  10. def get_obj(x, y):
  11.     def obj(alpha):
  12.         second = np.array([alpha[i]*alpha[j]*y[i]*y[j]*np.dot(x[i], x[j]) \
  13.                 for i in range(len(alpha)) for j in range(len(alpha))])
  14.         return  - np.sum(alpha) + .5 * np.sum(second)
  15.     return obj
  16.  
  17. def get_g(i):
  18.     def g(alpha):
  19.         return alpha[i]
  20.     return g
  21.  
  22. def get_h(y):
  23.     def h(alpha):
  24.         return np.sum([alpha[i]*y[i] for i in range(len(alpha))])
  25.     return h
  26.  
  27. cons = [
  28.     {'type': 'eq', 'fun': get_h(y)}
  29. ]
  30. for i in range(len(x)):
  31.     cons.append({'type': 'ineq', 'fun': get_g(i)})
  32.  
  33. res = minimize(get_obj(x, y), np.zeros(len(x)), constraints=cons, method="SLSQP")
  34. alpha_hat = res.x
  35. tol = 1e-10
  36. support_index = np.where(alpha_hat > tol)[0]
  37. non_suppoert_index = np.where(alpha_hat <= tol)[0]
  38. alpha_hat[non_suppoert_index] = 0
  39. w = np.sum([alpha_hat[i]*y[i]*x[i] for i in range(len(x))], axis=0)
  40. x_m = x[support_index][np.where(y[support_index]==-1)]
  41. x_p = x[support_index][np.where(y[support_index]==1)]
  42. b = - ( np.dot(w, x_p[0]) + np.dot(w, x_m[0]) ) / 2
  43.  
  44. def classifier(w1, w2, b, x):
  45.     return -w1/w2*x - b/w2
  46.  
  47. for xi, yi in zip(x, y):
  48.     if (xi[0] in x_m and xi[1] in x_m) or \
  49.             (xi[0] in x_p and xi[1] in x_p):
  50.         plt.plot(xi[0], xi[1], 'gv')
  51.         continue
  52.     if yi == 1:
  53.         plt.plot(xi[0], xi[1], 'ro')
  54.     else:
  55.         plt.plot(xi[0], xi[1], 'bx')
  56. x = np.linspace(-5, 5, 100)
  57. y = classifier(w[0], w[1], b, x)
  58. plt.plot(x, y)
  59. plt.xlim([-5, 5])
  60. plt.ylim([-5, 5])
  61. plt.show()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top