决策树(Decisiontree)是基于已知各种情况(特征取值)的基础上,通过构建树型决策结构来进行分析的一种方式,是常用的有监督的分类算法(也就是带有标签的训练数据集训练的,比如后文中使用到的训练集中的好瓜坏瓜就是标签,形容瓜的就是特征)
决策树模型(DecisionTreemodel)模拟人类决策过程。
根节点:决策树的起点,代表数据集的整体。
内部节点:表示对某个特征进行的判断或测试,也可以说是类别二选一。
分支:从一个节点到另一个节点的路径,根据特征的取值进行分割,表示一个测试输出。
叶节点:代表最终的决策或预测结果。
选择特征:决策树通过选取最能分割数据的特征来构建内部节点。通常使用信息增益(InformationGain)或基尼系数(GiniImpurity)等标准来衡量特征的重要性,这些标准后面还会谈到。
分裂:根据选定的特征,将数据集分成若干子集,每个子集对应一个特定的特征取值或范围。
递归分裂:对子集重复上述过程,构建子树,直到满足停止条件(如节点纯度达到阈值、最大深度达到、数据量不足等)。
终止条件:当不能再有效分裂时,节点转化为叶节点,叶节点的输出即为分类标签或回归值。
图里搞的很复杂,重点其实就在递归。
要了解决策树的「最优属性」选择,我们需要先了解一个信息论的概念「信息熵(entropy)」,它是消除不确定性所需信息量的度量,也是未知事件可能含有的信息量。
假设数据集\(D\)中有\(y\)类,其中第\(k\)类样本占比为\(p_k\),则信息熵的计算公式如下:
\(ENT(D)=-\sum_{k=1}^{|y|}p_k\log_2p_k\)
\(p_k\)为1时,信息熵最小为0,很明显为必然事件,\(p_k\)为均匀分布(概率相等)时,信息商取最大值(\(p_k=\frac{1}{y}\))\(\log_2(|y|)\)(概率同等,不确定性最大)
还记得我们之前的决策树家族中的ID3吗?构建时用的就是信息增益信息增益(InformationGain),它衡量的是我们选择某个属性进行划分时信息熵的变化(可以理解为基于这个规则划分,不确定性降低的程度)。典型的决策树算法ID3就是基于信息增益来挑选每一节点分支用于划分的属性(特征)的。
这里面的\(D^v\)可能有点难理解,它是将数据集\(D\)根据属性\(a\)的那些取值划分成了\(v\)个子集\(\{D_1,D_2,...,D_v\}\),那划分后的信息熵又是咋来的,其实是一种条件熵\(H(D|a)\),是数据集\(D\)在基于属性\(a\)进行划分后的不确定性。
下面拿一个西瓜的数据集举个例子,一共17个数据,9个好瓜,8个坏瓜
以色泽属性为条件计算信息熵,一共三类色泽:\(青绿,乌黑,浅白\),看看他们在好坏瓜中的占比进行计算
同样的方法,计算其他属性的信息增益为:
对比不同属性,我们发现「纹理」信息增益最大,它就要作为决策树的根节点,可以看到里面被分为三个属性:\(清晰,模糊,稍糊\),也就是下一层的节点要根据这三个属性来看,计算各属性信息增益
图中只给出了纹理=清晰这一个分支的结果,有三个属性信息增益都一样,那么说明这三个特征都是最能分割数据的特征,均作为决策树的节点。纹理=稍糊以及其他属性的计算过程略去了,最后的结果如下图
大家已经了解了信息增益作为特征选择的方法,但信息增益有一个问题,它偏向取值较多的特征。原因是,当特征的取值较多时,根据此特征划分更容易得到纯度更高的子集,因此划分之后的熵更低,由于划分前的熵是一定的。因此信息增益更大,因此信息增益比较偏向取值较多的特征。
那有没有解决这个小问题的方法呢?有的,这就是我们要提到信息增益率(GainRatio),信息增益率相比信息增益,多了一个衡量本身属性的分散程度的部分作为分母,而著名的决策树算法C4.5就是使用它作为划分属性挑选的原则。
\(Grain\_ratio(D,a)=\frac{Gain(D,a)}{IV(a)}\)\(IV(a)=-\sum_{v=1}^{V}\frac{|D^v|}{|D|}\log_2\frac{|D^v|}{|D|}\)
下面那一块就是熵公式的变式,固有熵通过计算特征自身的“熵”,使得信息增益率能够公平地评价特征的分裂能力,不偏向多值特征。
基尼系数\(Gini(D)=\sum_{k=1}^{|y|}\sum_{k^{'}\not=k}p_kp_{k^{'}}=1-\sum_{k=1}^{|y|}p_k^2\)
为什么它可以作为纯度的量度呢?大家可以想象在一个漆黑的袋里摸球,有不同颜色的球,其中第k类占比记作\(p_k\),那两次摸到的球都是第k类的概率就是\(p_k^2\),那两次摸到的球颜色不一致的概率就是\(1-\sump_k^2\),它的取值越小,两次摸球颜色不一致的概率就越小,纯度就越高。
如果我们让决策树一直生长,最后得到的决策树可能很庞大,而且因为对原始数据学习得过于充分会有过拟合的问题。缓解决策树过拟合可以通过剪枝操作完成。而剪枝方式又可以分为:预剪枝和后剪枝。并使用「留出法」进行评估剪枝前后决策树的优劣。
我们来看一个例子,下面的数据集,为了评价决策树模型的表现,会划分出一部分数据作为验证集
在上述西瓜数据集上生成的一颗完整的决策树,如下图所示。
剪枝基本策略包含「预剪枝」和「后剪枝」两个:
预剪枝(pre-pruning):在决策树生长过程中,对每个结点在划分前进行估计,若当前结点的划分不能带来决策树泛化性能的提升,则停止划分并将当前结点标记为叶结点。
后剪枝(post-pruning):先从训练集生成一颗完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能带来决策树泛化性能的提升,则将该子树替换为叶结点。
根据我们的验证集,如果只按照好坏瓜来进行分的话验证集精度为3/7x100%=42.9%
但是加入决策树的结点后,验证集精度为5/7x100%=71.4%,比没有划分前大,所以不剪枝,最后分出来如下图
和预剪枝一样是判断精度,只是从下面开始,没剪之前精度42.9,如果把结点⑥的标记为好瓜,精度57.1,可以剪
最终结果为
过/欠拟合风险:预剪枝:过拟合风险降低,欠拟合风险增加。后剪枝:过拟合风险降低,欠拟合风险基本不变。
泛化性能:后剪枝通常优于预剪枝。
我们用于学习的数据包含了连续值特征和离散值特征,之前的例子中使用的都是离散值属性(特征),决策树当然也能处理连续值属性,我们来看看它的处理方式。
对于离散取值的特征,决策树的划分方式是:选取一个最合适的特征属性,然后将集合按照这个特征属性的不同值划分为多个子集合,并且不断的重复这种操作的过程。
对于连续值属性,显然我们不能以这些离散值直接进行分散集合,否则每个连续值将会对应一种分类。那我们如何把连续值属性参与到决策树的建立中呢?
因为连续属性的可取值数目不再有限,因此需要连续属性离散化处理,常用的离散化策略是二分法,这个技术也是C4.5中采用的策略。
原始数据很多时候还会出现缺失值,决策树算法也能有效的处理含有缺失值的数据。缺失值处理的基本思路是:样本赋权,权重划分。
权重划分是指将整体权重分配给不同的部分或类别,确保模型能够有效地学习这些部分。例如,在决策树的构建过程中,使用样本的权重来影响节点的分裂决策。
我们来通过下图这份有缺失值的西瓜数据集,看看具体处理方式。
仅通过无缺失值的样例来判断划分属性的优劣,学习开始时,根结点包含样例集\({D}\)中全部17个样例,权重均为1。
\(\widetilde{D^1},\widetilde{D^2},\widetilde{D^3}\)分别表示在属性「色泽」上取值为「青绿」「乌黑」以及「浅白」的样本子集:
再计算其他属性的增益
因此选择「纹理」作为接下来的划分属性。感觉权重可能就体现在那个\(\widetilde{r_v}\)里,就是排除了缺失值的占比
用的是iris数据集,直接用sklearn库
样本数量:150个。特征数量:4个连续特征。类别数量:3个类别,每个类别包含50个样本。数据平衡:每个类别的样本数量相同,均为50个。
fromsklearn.datasetsimportload_irisfromsklearn.model_selectionimporttrain_test_splitfromsklearn.treeimportDecisionTreeClassifierfromsklearnimportmetricsimportmatplotlib.pyplotaspltfromsklearn.treeimportplot_tree#1.加载数据集iris=load_iris()X=iris.datay=iris.target#2.划分训练集和测试集X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3,random_state=42)#3.训练决策树模型clf=DecisionTreeClassifier()clf.fit(X_train,y_train)#4.进行预测y_pred=clf.predict(X_test)#5.评估模型print("Accuracy:",metrics.accuracy_score(y_test,y_pred))print("ClassificationReport:\n",metrics.classification_report(y_test,y_pred))#6.可视化决策树plt.figure(figsize=(12,8))plot_tree(clf,filled=True,feature_names=iris.feature_names,class_names=iris.target_names)plt.show()