• API
• FAQ
• Tools
• Archive
daily pastebin goal
37%
SHARE
TWEET

# Untitled

a guest Dec 16th, 2018 67 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. import numpy as np
2. from scipy.optimize import minimize
3. from matplotlib import pyplot as plt
4.
# Toy training set: two 2-D Gaussian clouds of ONE_DATA points each,
# one centred at (-2, -2) and labelled -1, the other at (2, 2) labelled +1.
ONE_DATA = 30
x = np.concatenate(
    [
        np.random.normal(loc=-2, scale=1.0, size=(ONE_DATA, 2)),  # class -1
        np.random.normal(loc=2, scale=1.0, size=(ONE_DATA, 2)),   # class +1
    ],
    axis=0,
)
y = np.repeat([-1, 1], ONE_DATA)
9.
def get_obj(x, y):
    """Build the (negated) hard-margin SVM dual objective for fixed data.

    The returned callable evaluates

        f(alpha) = -sum_i alpha_i
                   + 1/2 * sum_{i,j} alpha_i alpha_j y_i y_j <x_i, x_j>

    which the script minimises subject to its dual constraints.

    Args:
        x: array-like with one sample per row. A 1-D array is treated as one
            scalar feature per sample.
        y: array-like of per-sample labels (the script uses -1 / +1).

    Returns:
        ``obj(alpha) -> float``; ``alpha`` must have one entry per sample.
    """
    x = np.asarray(x, dtype=float)
    if x.ndim == 1:
        # Scalar features: <x_i, x_j> is then just x_i * x_j.
        x = x[:, np.newaxis]
    y = np.asarray(y, dtype=float)

    # Q_ij = y_i y_j <x_i, x_j> depends only on the data, so compute it once
    # here instead of rebuilding all n^2 terms on every optimiser call.
    yx = y[:, np.newaxis] * x
    q = yx @ yx.T

    def obj(alpha):
        alpha = np.asarray(alpha, dtype=float)
        return -np.sum(alpha) + 0.5 * (alpha @ q @ alpha)

    return obj
16.
def get_g(i):
    """Return a function mapping ``alpha`` to its ``i``-th component.

    The script registers it as an SLSQP ``'ineq'`` constraint, i.e. it
    requires ``alpha[i] >= 0``.
    """
    return lambda alpha: alpha[i]
21.
def get_h(y):
    """Build the SVM dual equality constraint ``h(alpha) = sum_i alpha_i y_i``.

    The script registers it as an SLSQP ``'eq'`` constraint, i.e. it
    requires ``h(alpha) == 0``.

    Args:
        y: array-like of per-sample labels.

    Returns:
        ``h(alpha) -> float``; ``alpha`` must have one entry per label.
    """
    # Convert once here rather than on every call made by the optimiser.
    y = np.asarray(y, dtype=float)

    def h(alpha):
        return np.dot(alpha, y)

    return h
26.
import warnings

# SVM dual constraints:
#   * equality   sum_i alpha_i y_i = 0  -> SLSQP 'eq' constraint;
#   * alpha_i >= 0 for every sample     -> passed as simple bounds.
# Bounds describe the same feasible set as one 'ineq' constraint per sample.
# Without an explicit 'jac', SLSQP would estimate the Jacobian of each of
# those n constraint functions numerically on every iteration.
cons = [
    {'type': 'eq', 'fun': get_h(y)},
]
bounds = [(0, None)] * len(x)

res = minimize(get_obj(x, y), np.zeros(len(x)), bounds=bounds,
               constraints=cons, method="SLSQP")
if not res.success:
    # The hard-margin dual is unbounded when the two classes are not linearly
    # separable, so failure is possible with random data; don't hide it.
    warnings.warn(f"SLSQP did not converge: {res.message}")
alpha_hat = res.x
# SLSQP stops at a finite tolerance (its `ftol` defaults to 1e-6), so the
# multipliers of non-support vectors typically come back as small noise rather
# than exact zeros. An absolute 1e-10 cut-off lets that noise through.
# Threshold relative to the largest multiplier instead.
tol = 1e-5 * np.max(alpha_hat)
support_index = np.where(alpha_hat > tol)[0]
alpha_hat[alpha_hat <= tol] = 0

# Primal weights: w = sum_i alpha_i y_i x_i (zeroed alphas contribute nothing).
w = (alpha_hat * y) @ x

x_s = x[support_index]
y_s = y[support_index]
x_m = x_s[y_s == -1]  # support vectors of class -1
x_p = x_s[y_s == 1]   # support vectors of class +1
if len(x_m) == 0 or len(x_p) == 0:
    raise RuntimeError(
        "expected support vectors from both classes; "
        "check that the optimisation converged"
    )

# Every support vector satisfies y_s * (w . x_s + b) = 1, i.e.
# b = y_s - w . x_s. Averaging over all of them is less sensitive to solver
# noise than using a single vector from each class.
b = np.mean(y_s - x_s @ w)
43.
def classifier(w1, w2, b, x):
    """Return the second coordinate of the line ``w1*x + w2*y + b = 0`` at ``x``.

    The script uses this to draw the separating line. ``x`` may be a scalar
    or a NumPy array, and ``w2`` must be non-zero.
    """
    slope = -w1 / w2
    intercept = -b / w2
    return slope * x + intercept
46.
# Scatter the training data: support vectors as green triangles, the other
# points as red circles (y == 1) or blue crosses (otherwise).
# Membership is decided by sample index. `scalar in 2-D array` would only
# test whether that number occurs anywhere in the array, not whether the
# point is one of its rows.
is_support = np.zeros(len(x), dtype=bool)
is_support[support_index] = True
is_positive = (y == 1) & ~is_support
is_negative = (y != 1) & ~is_support
plt.plot(x[is_support, 0], x[is_support, 1], 'gv')
plt.plot(x[is_positive, 0], x[is_positive, 1], 'ro')
plt.plot(x[is_negative, 0], x[is_negative, 1], 'bx')

# Decision boundary w . p + b = 0. Use fresh names so the training data held
# in `x` and `y` is not overwritten.
if np.isclose(w[1], 0):
    # The boundary is vertical and cannot be written as y = f(x).
    plt.axvline(-b / w[0])
else:
    x_line = np.linspace(-5, 5, 100)
    y_line = classifier(w[0], w[1], b, x_line)
    plt.plot(x_line, y_line)
plt.xlim([-5, 5])
plt.ylim([-5, 5])
plt.show()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.

Top