Advertisement
Guest User

Untitled

a guest
Nov 17th, 2019
76
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.10 KB | None | 0 0
  1. import numpy as np
  2. import pandas as pd
  3. import matplotlib.pyplot as plt
  4. from matplotlib.colors import ListedColormap
  5.  
  6.  
  7. class Perceptron():
  8.  
  9. def __init__(self, eta=0.01, n_iter=50, random_state=1):
  10. self.eta = eta
  11. self.n_iter = n_iter
  12. self.random_state = random_state
  13.  
  14. def fit(self, X, y):
  15.  
  16. rgen = np.random.RandomState(self.random_state)
  17. self.w_ = rgen.normal(loc=0.0, scale=0.01, size=1 + X.shape[1])
  18. self.errors_ = []
  19.  
  20. for _ in range(self.n_iter):
  21. errors = 0
  22. for xi, target in zip(X, y):
  23. update = self.eta * (target - self.predict(xi))
  24. self.w_[1:] += update * xi
  25. self.w_[0] += update
  26. errors += int(update != 0.0)
  27. print(self.w_)
  28. self.errors_.append(errors)
  29. return self
  30.  
  31. def net_input(self, X):
  32. """Calculate net input"""
  33. return np.dot(X, self.w_[1:]) + self.w_[0]
  34.  
  35. def predict(self, X):
  36. """Return class label after unit step"""
  37. return np.where(self.net_input(X) >= 0.0, 1, -1)
  38.  
  39.  
  40.  
  41. # ### Reading-in the Iris data
  42.  
  43.  
  44. df = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None)
  45. df.tail()
  46.  
  47. # ### Plotting the Iris data
  48.  
  49.  
  50.  
  51.  
  52. # select setosa and versicolor
  53. y = df.iloc[0:100, 4].values
  54. y = np.where(y == 'Iris-setosa', -1, 1)
  55.  
  56. # extract sepal length and petal length
  57. X = df.iloc[0:100, [0, 2]].values
  58.  
  59. # plot data
  60. plt.scatter(X[:50, 0], X[:50, 1],color='red', marker='o', label='setosa')
  61. plt.scatter(X[50:100, 0], X[50:100, 1],color='blue', marker='x', label='versicolor')
  62.  
  63. plt.xlabel('sepal length [cm]')
  64. plt.ylabel('petal length [cm]')
  65. plt.legend(loc='upper left')
  66.  
  67. # plt.savefig('images/02_06.png', dpi=300)
  68. plt.show()
  69.  
  70.  
  71.  
  72. # ### Training the perceptron model
  73.  
  74.  
  75. ppn = Perceptron(eta=0.1, n_iter=10)
  76.  
  77. ppn.fit(X, y)
  78.  
  79. plt.plot(range(1, len(ppn.errors_) + 1), ppn.errors_, marker='o')
  80. plt.xlabel('Epochs')
  81. plt.ylabel('Number of updates')
  82.  
  83. # plt.savefig('images/02_07.png', dpi=300)
  84. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement