sklearn DecisionTree 源碼分析

sklearn.tree._classes.BaseDecisionTree#fit
y至少爲1維(意思是可以處理multilabels數據)

y = np.atleast_1d(y)
if is_classifier(self):
    self.tree_ = Tree(self.n_features_,
                      self.n_classes_, self.n_outputs_)
else:
    self.tree_ = Tree(self.n_features_,
                      # TODO: tree should't need this in this case
                      np.array([1] * self.n_outputs_, dtype=np.intp),
                      self.n_outputs_)
self.n_outputs_ = y.shape[1]
self.n_classes_ = self.n_classes_[0]
self.n_classes_ = []
for k in range(self.n_outputs_):
    classes_k, y_encoded[:, k] = np.unique(y[:, k],
                                           return_inverse=True)
    self.classes_.append(classes_k)
    self.n_classes_.append(classes_k.shape[0])
np.unique([3,2,2,3,3,4], return_inverse=True)
Out[4]: (array([2, 3, 4]), array([1, 0, 0, 1, 1, 2]))

return_inverse類似於LabelEncode

sklearn.tree._tree.Tree

    def __cinit__(self, int n_features, np.ndarray[SIZE_t, ndim=1] n_classes,
                  int n_outputs):
  1. 特徵數
  2. 類別數
  3. label維度
# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
if max_leaf_nodes < 0:
    builder = DepthFirstTreeBuilder(splitter, min_samples_split,
                                    min_samples_leaf,
                                    min_weight_leaf,
                                    max_depth,
                                    self.min_impurity_decrease,
                                    min_impurity_split)
else:
    builder = BestFirstTreeBuilder(splitter, min_samples_split,
                                   min_samples_leaf,
                                   min_weight_leaf,
                                   max_depth,
                                   max_leaf_nodes,
                                   self.min_impurity_decrease,
                                   min_impurity_split)

scikit-learn決策樹算法類庫介紹

最大葉子節點數max_leaf_nodes

通過限制最大葉子節點數,可以防止過擬合,默認是"None”,即不限制最大的葉子節點數。如果加了限制,算法會建立在最大葉子節點數內最優的決策樹。如果特徵不多,可以不考慮這個值,但是如果特徵分成多的話,可以加以限制,具體的值可以通過交叉驗證得到。

sklearn.tree._tree.DepthFirstTreeBuilder#build

builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)
cpdef build(self, Tree tree, object X, np.ndarray y,
            np.ndarray sample_weight=None,
            np.ndarray X_idx_sorted=None):

注意到一個現象,這裏該有的參數都有,但是class_weight去哪了呢?懷疑是轉化了sample_weight

if self.class_weight is not None:
    expanded_class_weight = compute_sample_weight(
        self.class_weight, y_original)
if expanded_class_weight is not None:
    if sample_weight is not None:
        sample_weight = sample_weight * expanded_class_weight
    else:
        sample_weight = expanded_class_weight

sklearn/tree/_tree.pyx:203

splitter.init(X, y, sample_weight_ptr, X_idx_sorted)
cdef SIZE_t n_node_samples = splitter.n_samples
rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0)

rc是根節點,在分裂前含有所有的樣本

StackStackRecord都是sklearn自己寫的數據結構

is_leaf = (depth >= max_depth or
           n_node_samples < min_samples_split or
           n_node_samples < 2 * min_samples_leaf or
           weighted_n_node_samples < 2 * min_weight_leaf)
is_leaf = (is_leaf or (impurity <= min_impurity_split))

滿足以上條件直接停止分裂

sklearn.tree._splitter.BestSplitter

sklearn.tree._splitter.BestSplitter#node_split


scikit-learn uses an optimised version of the CART algorithm; however, scikit-learn implementation does not support categorical variables for now.

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章