Advertisement
Guest User

Untitled

a guest
Nov 21st, 2019
118
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.43 KB | None | 0 0
  1. import numpy as np
  2. from collections import Counter
  3.  
  4. # Note: an earlier version of this code worked in the previous task but not in this one, for unclear reasons,
  5. # so a mask is used to keep only split positions between distinct (non-equal) neighbouring feature values.
  6.  
  7. def find_best_split(feature_vector, target_vector):
  8.     sorted_inds = np.argsort(feature_vector)
  9.     feature_sorted = np.array(feature_vector)[sorted_inds]
  10.     mask = feature_sorted[1:] != feature_sorted[:-1]
  11.     target_sorted = np.array(target_vector)[sorted_inds]
  12.     thresholds = ((feature_sorted[1:] + feature_sorted[:-1]) / 2)[mask]
  13.  
  14.     sz = len(target_sorted)
  15.     l_sz = np.arange(1, sz)
  16.     l_1 = np.cumsum(target_sorted)
  17.     p_l_1 = l_1[:-1] / l_sz
  18.     p_l_0 = 1 - p_l_1
  19.     h_l = 1 - p_l_0**2 - p_l_1**2
  20.    
  21.  
  22.     r_sz = sz - l_sz
  23.     r_1 = l_1[-1] - l_1[:-1]
  24.     p_r_1 = r_1 / (sz - l_sz)
  25.     p_r_0 = 1 - p_r_1
  26.     h_r = 1 - p_r_0**2 - p_r_1**2
  27.  
  28.     ginis = (-(h_l) * l_sz / sz  -(h_r) * r_sz / sz)[mask]
  29.  
  30.     best = np.argmax(ginis)
  31.     threshold_best = thresholds[best]
  32.     gini_best = ginis[best]
  33.  
  34.     return thresholds, ginis, threshold_best, gini_best
  35.  
  36. class DecisionTree:
  37.     def __init__(self, feature_types, max_depth=None, min_samples_split=None, min_samples_leaf=None):
  38.         if np.any(list(map(lambda x: x != "real" and x != "categorical", feature_types))):
  39.             raise ValueError("There is unknown feature type")
  40.  
  41.         self._tree = {}
  42.         self._feature_types = feature_types
  43.         self._max_depth = max_depth
  44.         self._min_samples_split = min_samples_split
  45.         self._min_samples_leaf = min_samples_leaf
  46.  
  47.     def _fit_node(self, sub_X, sub_y, node):
  48.         if np.all(sub_y == sub_y[0]):
  49.             node["type"] = "terminal"
  50.             node["class"] = sub_y[0]
  51.             return
  52.  
  53.         feature_best, threshold_best, gini_best, split = None, None, None, None
  54.         for feature in range(sub_X.shape[1]):
  55.             if len(np.unique(sub_X[:, feature])) == 1:
  56.                 continue
  57.            
  58.             feature_type = self._feature_types[feature]
  59.             categories_map = {}
  60.  
  61.             if feature_type == "real":
  62.                 feature_vector = sub_X[:, feature]
  63.             elif feature_type == "categorical":
  64.                 counts = Counter(sub_X[:, feature])
  65.                 clicks = Counter(sub_X[sub_y == 1, feature])
  66.                 ratio = {}
  67.                 for key, current_count in counts.items():
  68.                     if key in clicks:
  69.                         current_click = clicks[key]
  70.                     else:
  71.                         current_click = 0
  72.                     ratio[key] = current_click / current_count
  73.                
  74.                 sorted_categories = list(map(lambda x: x[0], sorted(ratio.items(), key=lambda x: x[1])))
  75.                
  76.                 categories_map = dict(zip(sorted_categories, list(range(len(sorted_categories)))))
  77.                                      
  78.                 feature_vector = np.array(list(map(lambda x: categories_map[x], sub_X[:, feature])))
  79.                
  80.             else:
  81.                 raise ValueError
  82.  
  83.             _, _, threshold, gini = find_best_split(feature_vector, sub_y)
  84.            
  85.             if gini_best is None or gini > gini_best:
  86.                 feature_best = feature
  87.                 gini_best = gini
  88.                 split = feature_vector < threshold
  89.  
  90.                 if feature_type == "real":
  91.                     threshold_best = threshold
  92.                 elif feature_type == "categorical":
  93.                     threshold_best = list(map(lambda x: x[0],
  94.                                               filter(lambda x: x[1] < threshold, categories_map.items())))
  95.                 else:
  96.                     raise ValueError
  97.  
  98.         if feature_best is None:
  99.             node["type"] = "terminal"
  100.             node["class"] = Counter(sub_y).most_common(1)[0][0]
  101.             return
  102.  
  103.         node["type"] = "nonterminal"
  104.  
  105.         node["feature_split"] = feature_best
  106.         if self._feature_types[feature_best] == "real":
  107.             node["threshold"] = threshold_best
  108.         elif self._feature_types[feature_best] == "categorical":
  109.             node["categories_split"] = threshold_best
  110.         else:
  111.             raise ValueError
  112.         node["left_child"], node["right_child"] = {}, {}
  113.         self._fit_node(sub_X[split], sub_y[split], node["left_child"])
  114.         self._fit_node(sub_X[np.logical_not(split)], sub_y[np.logical_not(split)], node["right_child"])
  115.  
  116.     def _predict_node(self, x, node):
  117.         if node["type"] == "terminal":
  118.             return node["class"]
  119.  
  120.         feature = node["feature_split"]
  121.         feature_type = self._feature_types[feature]
  122.  
  123.         if feature_type == "real":
  124.             if x[feature] < node["threshold"]:
  125.                 return self._predict_node(x, node["left_child"])
  126.             return self._predict_node(x, node["right_child"])
  127.         elif feature_type == "categorical":
  128.             if x[feature] in node["categories_split"]:
  129.                 return self._predict_node(x, node["left_child"])
  130.             return self._predict_node(x, node["right_child"])
  131.  
  132.     def fit(self, X, y):
  133.         self._fit_node(X, y, self._tree)
  134.  
  135.     def predict(self, X):
  136.         predicted = []
  137.         for x in X:
  138.             predicted.append(self._predict_node(x, self._tree))
  139.         return np.array(predicted)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement