Advertisement
Guest User

Untitled

a guest
Jun 24th, 2019
75
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.16 KB | None | 0 0
  1. class MySupportVectorMachine:
  2.  
  3. def __init__(self):
  4. self.w_vector = []
  5. self.bias = 0
  6. self.support_ = []
  7.  
  8. #---------------------------------------------------
  9. def get_support_vector_indices(self,Xdata, Ydata):
  10.  
  11.  
  12. #pass
  13. #---------------------------------------------------
  14. def fit(self,Xdata,Ydata):
  15. #Initialize our SVMs weight vector with zeros
  16. w = np.zeros(len(Xdata[0]))
  17. #The learning rate
  18. eta = 0.0001
  19. #how many iterations to train for
  20. epochs = 1000
  21. #store misclassifications so we can plot how they change over time
  22. bias = 0
  23. C = 0.001
  24. print("nnTraining... ",end ="" )
  25. #training part, gradient descent part
  26. for epoch in range(1,epochs):
  27. if epoch %1000 == 0:
  28. print("." , end = "")
  29. for i, x in enumerate(Xdata):
  30. #misclassification
  31. if Ydata[i] == 0:
  32. Ydata[i]= -1
  33.  
  34. if (Ydata[i]*np.dot(Xdata[i], w) + bias) < 1:
  35. #misclassified update for ours weights
  36. w = w + eta * ( C *(Xdata[i] * Ydata[i]) + (-2 *(1/epoch)* w) )
  37. bias = bias + eta * (Ydata[i] * C)
  38. else:
  39. #correct classification, update our weights
  40. w = w + eta * (-2 *(1/epoch)* w)
  41. if Ydata[i] == -1:
  42. Ydata[i]= 0
  43. self.w_vector = w
  44. self.bias = bias
  45. print("ntttSVM is Successfully Trained...n")
  46.  
  47.  
  48.  
  49. self.get_support_vector_indices(Xdata, Ydata)
  50.  
  51. #---------------------------------------------------
  52. def predict(self, XtestData):
  53. #print("nn w_vector: n",self.w_vector)
  54. #print("nn bias: n",self.bias)
  55. predictedDataList = []
  56. for xi in XtestData:
  57. predictedData = np.sign(np.dot(self.w_vector,xi) + self.bias).astype(int)
  58. if predictedData == -1:
  59. predictedData = 0
  60. predictedDataList.append(predictedData)
  61. return(np.array(predictedDataList))
  62.  
  63. #---------------------------------------------------
  64.  
  65. # End of the Class
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement