24.决策树剪枝
约 1812 字大约 6 分钟
2025-09-20
决策树的复杂度容易过高,且容易过拟合,所以我们需要通过剪枝来降低决策树的复杂度,提升其在新数据上的泛化能力.
手段:
限制深度(节点层数/树的高度)
限制广度(叶子节点个数)
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
# Load the Iris dataset that ships with scikit-learn.
iris = load_iris()
# Keep only the first two feature columns.
X = iris.data[:,:2]
y = iris.target
# Tree built with default settings: no depth or sample-count limit is passed.
dtc = DecisionTreeClassifier()
dtc.fit(X, y)
运行结果
DecisionTreeClassifier()
plot_decision_boundary(X,y,dtc)
from sklearn.tree import plot_tree
plot_tree(dtc)
运行结果
[Text(0.36890243902439024, 0.9615384615384616, 'x[0] <= 5.45\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]'), Text(0.0975609756097561, 0.8846153846153846, 'x[1] <= 2.8\ngini = 0.237\nsamples = 52\nvalue = [45, 6, 1]'), Text(0.04878048780487805, 0.8076923076923077, 'x[0] <= 4.7\ngini = 0.449\nsamples = 7\nvalue = [1, 5, 1]'), Text(0.024390243902439025, 0.7307692307692307, 'gini = 0.0\nsamples = 1\nvalue = [1, 0, 0]'), Text(0.07317073170731707, 0.7307692307692307, 'x[0] <= 4.95\ngini = 0.278\nsamples = 6\nvalue = [0, 5, 1]'), Text(0.04878048780487805, 0.6538461538461539, 'x[1] <= 2.45\ngini = 0.5\nsamples = 2\nvalue = [0, 1, 1]'), Text(0.024390243902439025, 0.5769230769230769, 'gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]'), Text(0.07317073170731707, 0.5769230769230769, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'), Text(0.0975609756097561, 0.6538461538461539, 'gini = 0.0\nsamples = 4\nvalue = [0, 4, 0]'), Text(0.14634146341463414, 0.8076923076923077, 'x[0] <= 5.35\ngini = 0.043\nsamples = 45\nvalue = [44, 1, 0]'), Text(0.12195121951219512, 0.7307692307692307, 'gini = 0.0\nsamples = 39\nvalue = [39, 0, 0]'), Text(0.17073170731707318, 0.7307692307692307, 'x[1] <= 3.2\ngini = 0.278\nsamples = 6\nvalue = [5, 1, 0]'), Text(0.14634146341463414, 0.6538461538461539, 'gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]'), Text(0.1951219512195122, 0.6538461538461539, 'gini = 0.0\nsamples = 5\nvalue = [5, 0, 0]'), Text(0.6402439024390244, 0.8846153846153846, 'x[0] <= 6.15\ngini = 0.546\nsamples = 98\nvalue = [5, 44, 49]'), Text(0.4146341463414634, 0.8076923076923077, 'x[1] <= 3.45\ngini = 0.508\nsamples = 43\nvalue = [5, 28, 10]'), Text(0.3902439024390244, 0.7307692307692307, 'x[0] <= 5.75\ngini = 0.388\nsamples = 38\nvalue = [0, 28, 10]'), Text(0.24390243902439024, 0.6538461538461539, 'x[1] <= 2.85\ngini = 0.208\nsamples = 17\nvalue = [0, 15, 2]'), Text(0.21951219512195122, 0.5769230769230769, 'x[0] <= 5.55\ngini = 0.278\nsamples = 12\nvalue = [0, 10, 2]'), Text(0.1951219512195122, 
0.5, 'gini = 0.0\nsamples = 5\nvalue = [0, 5, 0]'), Text(0.24390243902439024, 0.5, 'x[1] <= 2.55\ngini = 0.408\nsamples = 7\nvalue = [0, 5, 2]'), Text(0.1951219512195122, 0.4230769230769231, 'x[0] <= 5.65\ngini = 0.5\nsamples = 2\nvalue = [0, 1, 1]'), Text(0.17073170731707318, 0.34615384615384615, 'gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]'), Text(0.21951219512195122, 0.34615384615384615, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'), Text(0.2926829268292683, 0.4230769230769231, 'x[0] <= 5.65\ngini = 0.32\nsamples = 5\nvalue = [0, 4, 1]'), Text(0.2682926829268293, 0.34615384615384615, 'x[1] <= 2.75\ngini = 0.5\nsamples = 2\nvalue = [0, 1, 1]'), Text(0.24390243902439024, 0.2692307692307692, 'gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]'), Text(0.2926829268292683, 0.2692307692307692, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'), Text(0.3170731707317073, 0.34615384615384615, 'gini = 0.0\nsamples = 3\nvalue = [0, 3, 0]'), Text(0.2682926829268293, 0.5769230769230769, 'gini = 0.0\nsamples = 5\nvalue = [0, 5, 0]'), Text(0.5365853658536586, 0.6538461538461539, 'x[1] <= 3.1\ngini = 0.472\nsamples = 21\nvalue = [0, 13, 8]'), Text(0.5121951219512195, 0.5769230769230769, 'x[1] <= 2.95\ngini = 0.488\nsamples = 19\nvalue = [0, 11, 8]'), Text(0.4634146341463415, 0.5, 'x[1] <= 2.85\ngini = 0.459\nsamples = 14\nvalue = [0, 9, 5]'), Text(0.43902439024390244, 0.4230769230769231, 'x[0] <= 5.9\ngini = 0.486\nsamples = 12\nvalue = [0, 7, 5]'), Text(0.36585365853658536, 0.34615384615384615, 'x[1] <= 2.65\ngini = 0.5\nsamples = 6\nvalue = [0, 3, 3]'), Text(0.34146341463414637, 0.2692307692307692, 'gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]'), Text(0.3902439024390244, 0.2692307692307692, 'x[1] <= 2.75\ngini = 0.48\nsamples = 5\nvalue = [0, 2, 3]'), Text(0.36585365853658536, 0.19230769230769232, 'gini = 0.5\nsamples = 4\nvalue = [0, 2, 2]'), Text(0.4146341463414634, 0.19230769230769232, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'), Text(0.5121951219512195, 0.34615384615384615, 'x[1] <= 
2.65\ngini = 0.444\nsamples = 6\nvalue = [0, 4, 2]'), Text(0.4878048780487805, 0.2692307692307692, 'x[0] <= 6.05\ngini = 0.444\nsamples = 3\nvalue = [0, 1, 2]'), Text(0.4634146341463415, 0.19230769230769232, 'gini = 0.5\nsamples = 2\nvalue = [0, 1, 1]'), Text(0.5121951219512195, 0.19230769230769232, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'), Text(0.5365853658536586, 0.2692307692307692, 'gini = 0.0\nsamples = 3\nvalue = [0, 3, 0]'), Text(0.4878048780487805, 0.4230769230769231, 'gini = 0.0\nsamples = 2\nvalue = [0, 2, 0]'), Text(0.5609756097560976, 0.5, 'x[0] <= 5.95\ngini = 0.48\nsamples = 5\nvalue = [0, 2, 3]'), Text(0.5365853658536586, 0.4230769230769231, 'gini = 0.5\nsamples = 2\nvalue = [0, 1, 1]'), Text(0.5853658536585366, 0.4230769230769231, 'x[0] <= 6.05\ngini = 0.444\nsamples = 3\nvalue = [0, 1, 2]'), Text(0.5609756097560976, 0.34615384615384615, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'), Text(0.6097560975609756, 0.34615384615384615, 'gini = 0.5\nsamples = 2\nvalue = [0, 1, 1]'), Text(0.5609756097560976, 0.5769230769230769, 'gini = 0.0\nsamples = 2\nvalue = [0, 2, 0]'), Text(0.43902439024390244, 0.7307692307692307, 'gini = 0.0\nsamples = 5\nvalue = [5, 0, 0]'), Text(0.8658536585365854, 0.8076923076923077, 'x[0] <= 7.05\ngini = 0.413\nsamples = 55\nvalue = [0, 16, 39]'), Text(0.8414634146341463, 0.7307692307692307, 'x[1] <= 2.4\ngini = 0.467\nsamples = 43\nvalue = [0, 16, 27]'), Text(0.8170731707317073, 0.6538461538461539, 'gini = 0.0\nsamples = 2\nvalue = [0, 2, 0]'), Text(0.8658536585365854, 0.6538461538461539, 'x[0] <= 6.95\ngini = 0.45\nsamples = 41\nvalue = [0, 14, 27]'), Text(0.8414634146341463, 0.5769230769230769, 'x[1] <= 3.15\ngini = 0.439\nsamples = 40\nvalue = [0, 13, 27]'), Text(0.7317073170731707, 0.5, 'x[0] <= 6.55\ngini = 0.471\nsamples = 29\nvalue = [0, 11, 18]'), Text(0.6829268292682927, 0.4230769230769231, 'x[1] <= 2.95\ngini = 0.375\nsamples = 16\nvalue = [0, 4, 12]'), Text(0.6585365853658537, 0.34615384615384615, 'x[0] <= 
6.45\ngini = 0.444\nsamples = 12\nvalue = [0, 4, 8]'), Text(0.6341463414634146, 0.2692307692307692, 'x[1] <= 2.85\ngini = 0.397\nsamples = 11\nvalue = [0, 3, 8]'), Text(0.5853658536585366, 0.19230769230769232, 'x[1] <= 2.6\ngini = 0.219\nsamples = 8\nvalue = [0, 1, 7]'), Text(0.5609756097560976, 0.11538461538461539, 'gini = 0.5\nsamples = 2\nvalue = [0, 1, 1]'), Text(0.6097560975609756, 0.11538461538461539, 'gini = 0.0\nsamples = 6\nvalue = [0, 0, 6]'), Text(0.6829268292682927, 0.19230769230769232, 'x[0] <= 6.25\ngini = 0.444\nsamples = 3\nvalue = [0, 2, 1]'), Text(0.6585365853658537, 0.11538461538461539, 'gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]'), Text(0.7073170731707317, 0.11538461538461539, 'x[0] <= 6.35\ngini = 0.5\nsamples = 2\nvalue = [0, 1, 1]'), Text(0.6829268292682927, 0.038461538461538464, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'), Text(0.7317073170731707, 0.038461538461538464, 'gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]'), Text(0.6829268292682927, 0.2692307692307692, 'gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]'), Text(0.7073170731707317, 0.34615384615384615, 'gini = 0.0\nsamples = 4\nvalue = [0, 0, 4]'), Text(0.7804878048780488, 0.4230769230769231, 'x[0] <= 6.65\ngini = 0.497\nsamples = 13\nvalue = [0, 7, 6]'), Text(0.7560975609756098, 0.34615384615384615, 'gini = 0.0\nsamples = 2\nvalue = [0, 2, 0]'), Text(0.8048780487804879, 0.34615384615384615, 'x[1] <= 2.65\ngini = 0.496\nsamples = 11\nvalue = [0, 5, 6]'), Text(0.7804878048780488, 0.2692307692307692, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'), Text(0.8292682926829268, 0.2692307692307692, 'x[1] <= 2.9\ngini = 0.5\nsamples = 10\nvalue = [0, 5, 5]'), Text(0.8048780487804879, 0.19230769230769232, 'gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]'), Text(0.8536585365853658, 0.19230769230769232, 'x[0] <= 6.75\ngini = 0.494\nsamples = 9\nvalue = [0, 4, 5]'), Text(0.8048780487804879, 0.11538461538461539, 'x[1] <= 3.05\ngini = 0.48\nsamples = 5\nvalue = [0, 3, 2]'), Text(0.7804878048780488, 
0.038461538461538464, 'gini = 0.5\nsamples = 2\nvalue = [0, 1, 1]'), Text(0.8292682926829268, 0.038461538461538464, 'gini = 0.444\nsamples = 3\nvalue = [0, 2, 1]'), Text(0.9024390243902439, 0.11538461538461539, 'x[0] <= 6.85\ngini = 0.375\nsamples = 4\nvalue = [0, 1, 3]'), Text(0.8780487804878049, 0.038461538461538464, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'), Text(0.926829268292683, 0.038461538461538464, 'gini = 0.444\nsamples = 3\nvalue = [0, 1, 2]'), Text(0.9512195121951219, 0.5, 'x[0] <= 6.45\ngini = 0.298\nsamples = 11\nvalue = [0, 2, 9]'), Text(0.926829268292683, 0.4230769230769231, 'x[1] <= 3.35\ngini = 0.444\nsamples = 6\nvalue = [0, 2, 4]'), Text(0.9024390243902439, 0.34615384615384615, 'x[0] <= 6.35\ngini = 0.5\nsamples = 4\nvalue = [0, 2, 2]'), Text(0.8780487804878049, 0.2692307692307692, 'gini = 0.5\nsamples = 2\nvalue = [0, 1, 1]'), Text(0.926829268292683, 0.2692307692307692, 'gini = 0.5\nsamples = 2\nvalue = [0, 1, 1]'), Text(0.9512195121951219, 0.34615384615384615, 'gini = 0.0\nsamples = 2\nvalue = [0, 0, 2]'), Text(0.975609756097561, 0.4230769230769231, 'gini = 0.0\nsamples = 5\nvalue = [0, 0, 5]'), Text(0.8902439024390244, 0.5769230769230769, 'gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]'), Text(0.8902439024390244, 0.7307692307692307, 'gini = 0.0\nsamples = 12\nvalue = [0, 0, 12]')]
发现,树的层级过多,同时训练时间特别长,很容易导致过拟合现象产生,所以我们需要剪枝.
决策树剪枝
# Limit the tree to a maximum depth of 2 when fitting.
clf = DecisionTreeClassifier(max_depth=2)
clf.fit(X,y)
plot_tree(clf)
运行结果
[Text(0.5, 0.8333333333333334, 'x[0] <= 5.45\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]'), Text(0.25, 0.5, 'x[1] <= 2.8\ngini = 0.237\nsamples = 52\nvalue = [45, 6, 1]'), Text(0.125, 0.16666666666666666, 'gini = 0.449\nsamples = 7\nvalue = [1, 5, 1]'), Text(0.375, 0.16666666666666666, 'gini = 0.043\nsamples = 45\nvalue = [44, 1, 0]'), Text(0.75, 0.5, 'x[0] <= 6.15\ngini = 0.546\nsamples = 98\nvalue = [5, 44, 49]'), Text(0.625, 0.16666666666666666, 'gini = 0.508\nsamples = 43\nvalue = [5, 28, 10]'), Text(0.875, 0.16666666666666666, 'gini = 0.413\nsamples = 55\nvalue = [0, 16, 39]')]

clf.score(X,y)
运行结果
0.7733333333333333
import matplotlib.pyplot as plt
from sklearn.inspection import DecisionBoundaryDisplay
def plot_decision_boundary(x, y, clf):
    """Plot the decision regions of a fitted classifier with the samples on top.

    Args:
        x: Array-like feature data of shape (n_samples, n_features). Only
            columns 0 and 1 are drawn as points, so the data is expected to
            be two-dimensional in practice.
        y: Array-like labels of shape (n_samples,); used to color the points.
        clf: A fitted scikit-learn classifier.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Convert once so that list-of-lists input can be column-sliced for the
    # scatter plot; the estimator still receives the caller's original ``x``.
    points = np.asarray(x)

    fig, ax = plt.subplots(figsize=(10, 6))

    # Shade the plane by the class the estimator predicts at each grid point.
    DecisionBoundaryDisplay.from_estimator(
        clf,
        x,
        response_method="predict",
        alpha=0.5,
        ax=ax,
        grid_resolution=300,
        cmap=plt.cm.coolwarm,
    )

    # Overlay the samples, colored by their label.
    scatter = ax.scatter(
        points[:, 0],
        points[:, 1],
        c=y,
        edgecolor="k",
        alpha=0.8,
        cmap=plt.cm.coolwarm,
        s=50,
    )

    # Build the legend from the scatter's color mapping and pin it to the axes.
    legend1 = ax.legend(*scatter.legend_elements(), title="Classes")
    ax.add_artist(legend1)

    # Title and axis labels.
    ax.set_title(f"Decision Boundary of {type(clf).__name__}")
    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")

    plt.tight_layout()
    plt.show()

但是发现模型好像又有些欠拟合
plot_decision_boundary(X,y,clf)
# Refit, this time constraining node size rather than depth.
clf = DecisionTreeClassifier(min_samples_split=20)
# min_samples_split: a node holding fewer samples than this is not split further.
# min_samples_leaf: a split is only kept if each resulting leaf retains at
# least this many samples.
clf.fit(X,y)
plot_decision_boundary(X,y,clf)
