JobPlus知识库 IT 工业智能4.0 文章
机器学习实战笔记:树回归

 使用ID3算法构建的决策树有如下问题:

  • 每次选取当前最佳的特征来分割数据,并按照该特征所有可能的取值来切分。也就是说,一个特征有n个取值,那么数据就会被分割成n份。
  • 使用某一特征来分割数据后,该特征在之后的算法执行过程中将不会再起作用,这种切分方式过于迅速。
  • 不能直接处理连续型特征,只有事先将连续型特征转换成离散型,才能使用ID3算法。

      CART算法是一种基于“基尼指数”的决策树构建算法,常用于回归树的构建中。回归树与分类树思想基本一致,但叶结点的数据类型不是离散型而是连续型。

      树回归的一般流程:

  1. 收集数据;
  2. 准备数据:需要数值型的数据,标称型数据应该映射成二值型数据;
  3. 分析数据:绘出数据的二维可视化显示结果,以字典方式生成树;
  4. 训练算法;
  5. 测试算法;
  6. 使用算法:用训练出的树做预测;

连续和离散型特征的树的构建

       使用字典来存储树的数据结构,该字典包含以下4个元素:

  • 待切分的特征
  • 待切分的特征值
  • 右子树。当不需要切分时,也可以是单个值
  • 左子树:与右子树类似

        函数createTree()的伪代码如下:

[python] 

  1. def createTree():  
  2.     找到最佳的待切分特征:  
  3.         if 该节点不能再分:  
  4.             将该节点村委叶节点  
  5.         执行二元切分  
  6.         在右子树调用createTree()  
  7.         在左子树调用createTree()  

        构建树的代码如下:

[python]

  1. def loadDataSet(fileName):  
  2.     dataMat=[]  
  3.     fr=open(fileName)  
  4.     for line in fr.readlines():  
  5.         curLine=line.strip().split('\t')  
  6.         fltLine=map(float,curLine)    #将每一行映射成浮点数  
  7.         dataMat.append(fltLine)  
  8.     return dataMat  
  9.   
  10. """切分数据集(注意这里书上错了)"""  
  11. def binSplitDataSet(dataSet,feature,value):#参数:数据集、待切分的特征、该特征的某个值  
  12.     #通过数组过滤方式将数据集切分得到两个子集返回  
  13.     mat0=dataSet[nonzero(dataSet[:,feature]>value)[0],:]    #选出指定特征feature满足大于特征值value的样本数据  
  14.     mat1=dataSet[nonzero(dataSet[:, feature]<=value)[0],:]  #选出指定特征feature满足小于等于特征值value的样本数据  
  15.     return mat0,mat1  
  16.   
  17. """树构建函数"""  
  18. def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):  
  19.     #leafType是对创建叶结点的引用;errType是对总误差方差计算函数的引用;ops是用户定义的参数构成的元组,用于树构建  
  20.     feat,val=chooseBestSplit(dataSet,leafType,errType,ops)     #根据基尼指数选择最好的特征用于划分  
  21.     if feat==None:  
  22.         return val  
  23.     retTree={}  
  24.     retTree['spInd']=feat  
  25.     retTree['spVal']=val  
  26.     lSet,rSet=binSplitDataSet(dataSet,feat,val)  
  27.     retTree['left']=createTree(lSet,leafType,errType,ops)   #递归构建左子树  
  28.     retTree['right']=createTree(lSet,leafType,errType,ops)  #递归构建右子树  
  29.     return retTree  

        choosBestSplit函数(用于选择最好特征分割)的伪代码如下:

[python] 

  1. def chooseBestSplit():  
  2.     对每个特征:  
  3.         对每个特征值:  
  4.             将数据集切分成两份  
  5.             计算切分的误差  
  6.             如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差  
  7.     return 最佳切分的特征和阈值  

        代码实现如下:

[python] 

  1. """负责计算目标变量的平方误差"""  
  2. def regErr(dataSet):  
  3.     return var(dataSet[:,-1])*shape(dataSet)[0]    #均方差函数var(),因为要返回总方差,故要乘以样本个数  

[python] 

  1. """负责生成叶结点"""  
  2. #当chooseBestSplit函数确定不再对数据进行切分时,将调用该函数来得到叶结点的模型。在回归树中,该模型其实就是目标变量的均值  
  3. def regLeaf(dataSet):  
  4.     return mean(dataSet[:,-1])  

[python]

  1. def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):  
  2.     #tolS和tolN用于控制函数的停止时机  
  3.     tolS=ops[0]   #容许的误差下降值  
  4.     tolN=ops[1]   #切分的最少样本数  
  5.     #将预测值y(特征值)/分类类别转化成一个列表(dataSet[:,-1].T.tolist()[0])  
  6.     #set函数将这个列表转化成集合,即特征值不同的才会被放入集合  
  7.     #len计算集合长度,如果为1说明不同剩余特征值的数目为1,那么就不需要在切分,只要直接返回  
  8.     if len(set(dataSet[:,-1].T.tolist()[0]))==1:  
  9.         return None,leafType(dataSet)    #用leafType对数据集生成叶结点  
  10.     m,n=shape(dataSet)   #n是特征数和y的和  
  11.     S=errType(dataSet)  
  12.     bestS=inf  
  13.     bestIndex=0  
  14.     bestValue=0  
  15.     #在所有可能的特征及其可能的取值上遍历  
  16.     for featIndex in range(n-1):  
  17.         for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):  
  18.             mat0,mat1=binSplitDataSet(dataSet,featIndex,splitVal)  #切分数据集  
  19.             if(shape(mat0)[0]<tolN)or(shape(mat1)[0]<tolN):   #判断是否还需继续切分  
  20.                 continue  
  21.             newS=errType(mat0)+errType(mat1)  
  22.             if newS<bestS:  
  23.                 bestIndex=featIndex  
  24.                 bestValue=splitVal  
  25.                 bestS=newS  
  26.     if (S-bestS) < tolS:  
  27.         return None,leafType(dataSet)  
  28.     mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)  
  29.     if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):   #如果提前终止条件均不满足则返回切分特征和特征值  
  30.         return None,leafType(dataSet)  
  31.     return bestIndex,bestValue  

[python]

  1. if __name__=='__main__':  
  2.     myDat=loadDataSet('ex00.txt')  
  3.     myMat=mat(myDat)  
  4.     regTree=createTree(myMat)  
  5.     print(regTree)  

        这里要注意python3中map返回的的类型已经不是list而是可迭代对象,故要将map返回的对象做list处理才能使用。


如果觉得我的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!

¥ 打赏支持
127人赞 举报
分享到
用户评价(0)

暂无评价,你也可以发布评价哦:)

扫码APP

扫描使用APP

扫码使用

扫描使用小程序