# Untitled

a guest Dec 16th, 2018 67 Never
1. import numpy as np
2. from scipy.optimize import minimize
3. from matplotlib import pyplot as plt
4.
# Synthetic two-class data set: ONE_DATA points per class, each class drawn
# from a unit-scale normal distribution centred at (-2, -2) or (2, 2).
ONE_DATA = 30
x = np.concatenate(
    [
        np.random.normal(loc=-2, scale=1.0, size=(ONE_DATA, 2)),
        np.random.normal(loc=2, scale=1.0, size=(ONE_DATA, 2)),
    ],
    axis=0,
)
# Labels follow the row order of x: -1 for the first cluster, +1 for the second.
y = np.concatenate([np.full(ONE_DATA, -1), np.full(ONE_DATA, 1)])
9.
def get_obj(x, y):
    """Build the objective minimised to solve the SVM dual problem.

    The returned function evaluates

        -sum_i alpha_i + 0.5 * sum_ij alpha_i alpha_j y_i y_j <x_i, x_j>

    Args:
        x: Samples, one row per sample (a 1-D sequence is treated as
            scalar samples).
        y: One label per sample.

    Returns:
        A callable ``obj(alpha)`` that returns the scalar objective value.
        ``alpha`` must hold one multiplier per sample. ``x`` and ``y`` are
        read once, when ``get_obj`` is called.
    """
    samples = np.asarray(x)
    if samples.ndim == 1:
        samples = samples[:, None]
    # The double sum equals ||sum_i alpha_i y_i x_i||^2, so the label-signed
    # samples are computed once here and every evaluation is a single
    # matrix-vector product instead of an O(n^2) Python-level loop.
    signed_samples = np.asarray(y)[:, None] * samples

    def obj(alpha):
        alpha = np.asarray(alpha)
        combined = alpha @ signed_samples
        return -np.sum(alpha) + 0.5 * np.dot(combined, combined)

    return obj
16.
def get_g(i):
    """Return a function that picks entry ``i`` out of its argument.

    Used as the ``fun`` of an inequality constraint, one per multiplier.
    """
    return lambda alpha: alpha[i]
21.
def get_h(y):
    """Return a function computing ``sum_i alpha_i * y_i``.

    Used as the ``fun`` of the equality constraint.

    Args:
        y: One label per multiplier.

    Returns:
        A callable ``h(alpha)`` giving the dot product of ``alpha`` and
        ``y``. ``alpha`` must have the same length as ``y``.
    """
    labels = np.asarray(y)

    def h(alpha):
        # A single dot product replaces the element-by-element Python loop.
        return np.dot(alpha, labels)

    return h
26.
# Constraints handed to the solver: one equality (sum_i alpha_i * y_i == 0)
# followed by one inequality per sample (alpha_i >= 0).
cons = [{'type': 'eq', 'fun': get_h(y)}]
cons += [{'type': 'ineq', 'fun': get_g(i)} for i in range(len(x))]

res = minimize(get_obj(x, y), np.zeros(len(x)), constraints=cons, method="SLSQP")
alpha_hat = res.x
tol = 1e-10
# Multipliers above tol mark the support vectors; the rest are set to zero.
support_index = np.where(alpha_hat > tol)[0]
# The misspelled name is kept as-is in case other code refers to it.
non_suppoert_index = np.where(alpha_hat <= tol)[0]
alpha_hat[non_suppoert_index] = 0
# w = sum_i alpha_i * y_i * x_i, computed as one vector-matrix product.
w = (alpha_hat * y) @ x
x_m = x[support_index][np.where(y[support_index] == -1)]
x_p = x[support_index][np.where(y[support_index] == 1)]
if len(x_m) == 0 or len(x_p) == 0:
    # Without a support vector in each class the offset below cannot be
    # computed; report the solver status instead of failing with IndexError.
    raise RuntimeError(
        "no support vector found for at least one class "
        "(solver success=%s, message=%s)" % (res.success, res.message)
    )
# Offset taken midway between the first support vector of each class.
b = -(np.dot(w, x_p[0]) + np.dot(w, x_m[0])) / 2
43.
def classifier(w1, w2, b, x):
    """Return the second coordinate of the line ``w1*x + w2*y + b = 0`` at ``x``.

    The line is evaluated in slope/intercept form, so ``w2`` is the divisor
    of both terms.
    """
    slope = -w1 / w2
    intercept = -b / w2
    return slope * x + intercept
46.
# Scatter the samples: support vectors as green triangles, the remaining
# points as red circles (label 1) or blue crosses (any other label).
# Support vectors are identified by row index. Testing a coordinate with
# `in` against a 2-D array matches any element anywhere in that array,
# not a whole row, so it can mark points that are not support vectors.
support_rows = set(support_index.tolist())
for row, (xi, yi) in enumerate(zip(x, y)):
    if row in support_rows:
        marker = 'gv'
    elif yi == 1:
        marker = 'ro'
    else:
        marker = 'bx'
    plt.plot(xi[0], xi[1], marker)

# Draw the line w[0]*x + w[1]*y + b = 0 over the visible range. Separate
# names are used so the sample arrays x and y are not overwritten.
boundary_x = np.linspace(-5, 5, 100)
boundary_y = classifier(w[0], w[1], b, boundary_x)
plt.plot(boundary_x, boundary_y)
plt.xlim([-5, 5])
plt.ylim([-5, 5])
plt.show()
