Guest User

Untitled

a guest
Jul 19th, 2018
119
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.71 KB | None | 0 0
  1. import numpy as np
  2. import pandas as pd
  3.  
  4.  
  5. class CART_R(object):
  6. # node_min_num 停止划分的最小节点样本数
  7. # tol 停止划分的最小损失值
  8. def __init__(self, node_min_num, tol):
  9. self.node_min_num = node_min_num
  10. self.tol = tol
  11.  
  12. def fit(self, x_train, y_train):
  13. self.x_train = x_train
  14. self.y_train = y_train
  15. x_train = pd.DataFrame(x_train)
  16. self.Tree = self.__create_tree(x_train, y_train)
  17.  
  18. def predict(self,x_test):
  19. x_test = np.array(x_test)
  20. y_pre = []
  21. for x in x_test:
  22. y = self.__find_y(self.Tree, x)
  23. y_pre.append(y)
  24. return y_pre
  25.  
  26. def __find_y(self, tree, x):
  27. if isinstance(tree, dict):
  28. keys = list(tree.keys())
  29. j = keys[0][0]
  30. s = keys[0][1]
  31. if x[j] <= s:
  32. return self.__find_y(tree[(j, s, 'left')], x)
  33. else:
  34. return self.__find_y(tree[(j, s, 'right')], x)
  35. else:
  36. return tree
  37.  
  38. def __create_tree(self, x_train, y_train):
  39. # 如果节点样本数小于最小叶节样本数,停止树的增长,返回节点对应值
  40. if self.__stop_split_node_num(x_train.shape[0], self.node_min_num):
  41. return y_train.mean()
  42. else:
  43. # 如果没有特征可以划分了,停止划分
  44. if self.__j_is_none(x_train):
  45. return y_train.mean()
  46. else:
  47. result = self.__find_js(x_train, y_train, self.tol)
  48. # 如果返回的结果是'stop_split',表示损失值小于阈值,停止划分
  49. if result == 'stop_split':
  50. return y_train.mean()
  51. else:
  52. best_j, best_s = result
  53. l_x, l_y, r_x, r_y = self.__split_data(x_train, y_train, best_j, best_s)
  54. Tree = {(best_j, best_s, 'left'): self.__create_tree(l_x, l_y),
  55. (best_j, best_s, 'right'): self.__create_tree(r_x, r_y)}
  56. return Tree
  57.  
  58. # 当叶节点样本数量小于规定的最小值时,停止划分
  59. def __stop_split_node_num(self, node_num, min_num):
  60. if node_num <= min_num:
  61. return True
  62. else:
  63. return False
  64.  
  65. # 当寻找j,s时,算出的损失值低于阈值,停止划分
  66. def __stop_split_tol(self, js_values, tol):
  67. if js_values <= tol:
  68. return True
  69. else:
  70. return False
  71.  
  72. # 当没有特征可以用来划分时,停止划分
  73. def __j_is_none(self,x_train):
  74. for j in x_train.columns:
  75. if x_train[j].unique().size != 1:
  76. return False
  77. return True
  78.  
  79. # 寻找j,s
  80. def __find_js(self, x_train, y_train, tol):
  81. min_loss = np.inf
  82. best_j = 0
  83. best_s = 0
  84. for j in x_train.columns:
  85. # 找到特征j的所有可能值,并进行排序
  86. s_sort = x_train[j].unique()
  87. s_sort.sort()
  88. # 如果j只有一个值了,跳过该循环
  89. if s_sort.size == 1:
  90. continue
  91. else:
  92. for s in s_sort:
  93. # 算出js对应损失,以及左右单元的对应值
  94. loss, y_left, y_right = self.__js_loss(x_train, y_train, j, s)
  95. # 如果算出的损失低于阈值,停止循环
  96. if self.__stop_split_tol(loss, tol):
  97. return 'stop_split'
  98. else:
  99. if loss < min_loss:
  100. min_loss = loss
  101. best_j = j
  102. best_s = s
  103. return best_j, best_s
  104.  
  105. # 根据j,s算出损失
  106. def __js_loss(self, x_train, y_train, j, s):
  107. left = y_train[x_train[j] <= s]
  108. y_left = left.mean()
  109. loss_left = ((left - y_left)**2).sum()
  110. right = y_train[x_train[j] > s]
  111. y_right = right.mean()
  112. right_left = ((right - y_right) ** 2).sum()
  113. loss = loss_left+right_left
  114. return loss, y_left, y_right
  115.  
  116. # 按best_j和best_s进行划分
  117. def __split_data(self, x_train, y_train, best_j, best_s):
  118. # 左划分
  119. left_split = x_train[best_j] <= best_s
  120. l_x, l_y = x_train[left_split], y_train[left_split]
  121. # 右划分
  122. right_split = -left_split
  123. r_x, r_y = x_train[right_split], y_train[right_split]
  124. return l_x, l_y, r_x, r_y
  125.  
  126.  
  127. if __name__ == '__main__':
  128. from sklearn.datasets import load_boston
  129. from sklearn.metrics import mean_squared_error
  130. X, y = load_boston(True)
  131. mytree = CART_R(5, 0.01)
  132. mytree.fit(X, y)
  133. y_pre = mytree.predict(X)
  134. mse = mean_squared_error(y, y_pre)
  135. print('mse为%f' % mse)
Add Comment
Please, Sign In to add comment