JobPlus知识库 IT 工业智能4.0 文章
机器学习实战之决策树

决策时是一个分类算法。本文主要讲了一下决策树的构造以及用绘图的形式把决策树绘画出来。

决策树的构造

本文使用ID3算法来划分数据集,通过计算每一个特征的香农熵来选取最优划分数据集的特征,之后在递归的构造决策树来遍历每一个特征。


下面公式是计算香农熵,p(xi)是选择该分类的概率 ,n是分类的数目。


算法步骤:

  1. 利用calcShannonEnt函数计算原始数据的原始香农熵,即最后的一个特征来划分计算出来的香农熵,如下方的'yes','no'的特征计算出来的香农熵。

  2. 利用 chooseBestFeatureToSplit函数来计算每一个特征的香农熵,选取最大的香农熵的那一个特征来进行数据划分和构造决策树.

  3. 利用splitDataSet函数将该特征去掉的数据集继续遍历递归的计算每个特征的香农熵构造决策树。

决策树的存储:

  1. 使用字典来构造存储决策树的信息,之后可以用pickle模块来存储决策树。

下面是构造决策树的具体代码。 使用的是python3。


[python]

  1. # -*- coding: UTF-8 -*-  
  2. from math import log  
  3. from treePlotter import retrieveTree, createPlot  
  4. import operator  
  5. # ID3决策树算法  
  6.   
  7. # 测试数据集  
  8. def createDataSet():  # 列表每一项最后一列为类别标签  
  9.     dataSet = [[1, 1, 'yes'],  
  10.                [1, 1, 'yes'],  
  11.                [1, 0, 'no'],  
  12.                [0, 1, 'no'],  
  13.                [0, 1, 'yes']  
  14.                ]  
  15.     labels = ['no surfacing', 'flippers']  
  16.     return dataSet, labels  
  17.   
  18.   
  19. # 计算给定数据的香农熵    
  20. def calcShannonEnt(dataSet):  
  21.     numEntries = len(dataSet)  
  22.     labelCounts = {}  
  23.     for featVec in dataSet:  
  24.         currentLabel = featVec[-1]  
  25.         if currentLabel not in labelCounts.keys():  
  26.             labelCounts[currentLabel] = 0  
  27.         labelCounts[currentLabel] += 1  
  28.     shannonEnt = 0.0  
  29.     for key in labelCounts:  
  30.         prob = float(labelCounts[key]) / numEntries  
  31.         shannonEnt -= prob * log(prob, 2)  
  32.     return shannonEnt  
  33.   
  34.   
  35. # 划分数据集    按axis列来划分 之后将axis列去掉  
  36. def splitDataSet(dataSet, axis, value):  
  37.     retDataSet = []  
  38.     for featVec in dataSet:  
  39.         if featVec[axis] == value:  
  40.             reducedFeatVec = featVec[:axis]  
  41.             reducedFeatVec.extend(featVec[axis + 1:])  
  42.             retDataSet.append(reducedFeatVec)  # 注意append和extend 方法的区别  
  43.     return retDataSet  
  44.   
  45.   
  46. # 选择最好的数据集方式划分    香农熵越大代表选择该特征分类更好,选取最大的香农熵。  
  47. def chooseBestFeatureToSplit(dataSet):  
  48.     numFeatures = len(dataSet[0]) - 1  # 特征数量  
  49.     baseEntropy = calcShannonEnt(dataSet)  # 原始香农熵  
  50.     bestInfoGain = 0.0  
  51.     bestFeature = -1  
  52.     for i in range(numFeatures):  
  53.         featList = [example[i] for example in dataSet]  
  54.         uniqueVals = set(featList)  # set() 函数创建一个无序不重复元素集  
  55.         newEntropy = 0.0  
  56.         for value in uniqueVals:  
  57.             subDataSet = splitDataSet(dataSet, i, value)  
  58.             prob = len(subDataSet) / float(len(dataSet))  
  59.             newEntropy += prob * calcShannonEnt((subDataSet))  
  60.         infoGain = baseEntropy - newEntropy  
  61.         if (infoGain > bestInfoGain):  
  62.             bestInfoGain = infoGain  
  63.             bestFeature = i  
  64.     return bestFeature  
  65.   
  66.   
  67. # 返回出现次数最多的分类名称  
  68. def majorityCnt(classList):  
  69.     classCount = {}  
  70.     for vote in classList:  
  71.         if vote not in classCount.keys(): classCount[vote] = 0  
  72.         classCount[vote] += 1  
  73.     # sortedClassCount = sorted(classCount.iteritems (), key=operator.itemgetter(1), reverse=True)  # python 2.7的写法  
  74.     sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1),  
  75.                               reverse=True)  # python3 中iteritems 改为 items  
  76.     return sortedClassCount[0][0]  
  77.   
  78.   
  79. # 创建决策树  
  80. def createTree(dataSet, labels):  
  81.     classList = [example[-1] for example in dataSet]  
  82.     if classList.count(classList[0]) == len(classList):  # 当只剩一类的时候 直接返回该标签  
  83.         return classList[0]  
  84.     if len(dataSet[0]) == 1:  
  85.         return majorityCnt(classList)  
  86.   
  87.     bestFeat = chooseBestFeatureToSplit(dataSet)  # 选择最优的分类即香农熵最大的特征。  
  88.     bestFeatLabel = labels[bestFeat]  
  89.     myTree = {bestFeatLabel: {}}  # 字典型存储树  
  90.     del (labels[bestFeat])  
  91.     featValues = [example[bestFeat] for example in dataSet]  
  92.     uniqueVals = set(featValues)  
  93.     for value in uniqueVals:  
  94.         subLabels = labels[:]  
  95.         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)  
  96.     return myTree  
  97.   
  98.   
  99. def classify(inputTree, featLabels, testVec):  
  100.     firstStr = list(inputTree.keys())[0]  
  101.     secondDict = inputTree[firstStr]  
  102.     featIndex = featLabels.index(firstStr)  
  103.     for key in secondDict.keys():  
  104.         if testVec[featIndex] == key:  
  105.             if type(secondDict[key]).__name__ == 'dict':  
  106.                 classLabel = classify(secondDict[key], featLabels, testVec)  
  107.             else:  
  108.                 classLabel = secondDict[key]  
  109.     return classLabel  
  110.   
  111.   
  112. # 决策树的存储  
  113. def storeTree(inputTree, filename):  
  114.     import pickle  
  115.     with open(filename, 'wb') as fw:  
  116.         pickle.dump(inputTree, fw)  
  117.   
  118.   
  119. def grabTree(filename):  
  120.     import pickle  
  121.     fr = open(filename, 'rb')  # 注意加上‘rb’,否则会读不出来  
  122.     return pickle.load(fr)  
  123.   
  124.   
  125. if __name__ == '__main__':  
  126.     # myDat, labels = createDataSet()  
  127.     # myTree = retrieveTree(0)  
  128.     # print(classify(myTree, labels, [1, 0]))  
  129.     # storeTree(myTree, 'classifierStorage.txt')  
  130.     # print(grabTree('classifierStorage.txt'))  
  131.     # # print(splitDataSet(myDat,0,1))  
  132.     # # print(chooseBestFeatureToSplit(myDat))  
  133.     # # print(createTree(myDat, labels))  
  134.       
  135.     # 隐形眼镜的例子  
  136.     fr = open('lenses.txt')  
  137.     lenses = [inst.strip().split('\t') for inst in fr.readlines()]  
  138.     lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']  
  139.     lensesTree = createTree(lenses, lensesLabels)  
  140.     createPlot(lensesTree)  


决策树的树形图绘制


构造出决策树了然而字典的表示形式不易于理解,接下来使用Matplotlib库创建树形图。

[python] 

  1. # coding:utf-8  
  2. import matplotlib.pyplot as plt  
  3. import matplotlib  
  4.   
  5. plt.rcParams['font.sans-serif'] = ['SimHei']  
  6. plt.rcParams['axes.unicode_minus'] = False  
  7.   
  8. decisionNode = dict(boxstyle="sawtooth", fc="0.8")  
  9. leafNode = dict(boxstyle="round4", fc="0.8")  
  10. arrow_args = dict(arrowstyle="<-")  
  11.   
  12.   
  13. def plotNode(nodeText, centerPt, parentPt, nodeType):  
  14.     createPlot.ax1.annotate(nodeText, xy=parentPt, xycoords='axes fraction', xytext=centerPt,  
  15.                             textcoords='axes fraction', va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)  
  16.   
  17.   
  18. def createPlot(inTree):  
  19.     fig = plt.figure(1, facecolor='white')  
  20.     fig.clf()  
  21.     axprops = dict(xticks=[], yticks=[])  # 创建一个字典  
  22.     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  
  23.     plotTree.totalW = float(getNumLeafs(inTree))  
  24.     plotTree.totalD = float(getTreeDepth(inTree))  
  25.     plotTree.xOff = -0.5 / plotTree.totalW  
  26.     plotTree.yOff = 1.0  # x偏移  
  27.     plotTree(inTree, (0.5, 1.0), '')  # 绘制决策树  
  28.     plt.show()  
  29.     #     plotNode(U'决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)  
  30.     #     plotNode(U'叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)  
  31.     #     plt.show()  
  32.   
  33.       
  34. # 获取叶节点的数目  
  35. def getNumLeafs(myTree):  
  36.     numLeafs = 0  
  37.     firstStr = list(myTree.keys())[0]  # python3 和书中代码不一样,因为python3改变了dict.keys,返回 的是一个对象。  
  38.     secondDict = myTree[firstStr]  
  39.     for key in secondDict.keys():  
  40.         if type(secondDict[key]).__name__ == 'dict':  
  41.             numLeafs += getNumLeafs(secondDict[key])  
  42.         else:  
  43.             numLeafs += 1  
  44.     return numLeafs  
  45.   
  46.   
  47. # 获取树的层数  
  48. def getTreeDepth(myTree):  
  49.     maxDepth = 0  
  50.     firstStr = list(myTree.keys())[0]  
  51.     secondDict = myTree[firstStr]  
  52.     for key in secondDict.keys():  
  53.         if type(secondDict[key]).__name__ == 'dict':  
  54.             thisDepth = 1 + getTreeDepth(secondDict[key])  
  55.         else:  
  56.             thisDepth = 1  
  57.         if thisDepth > maxDepth:  
  58.             maxDepth = thisDepth  
  59.     return maxDepth  
  60.   
  61.   
  62. # 用于测试  
  63. def retrieveTree(i):  
  64.     listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},  
  65.                    {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]  
  66.     return listOfTrees[i]  
  67.   
  68.   
  69. def plotMidText(cntrPt, parentPt, txtString):  # cntrPt、parentPt 用于计算标注位置 txtString 标注的内容  
  70.     xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]  # 计算标注位置  
  71.     yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]  
  72.     createPlot.ax1.text(xMid, yMid, txtString)  
  73.   
  74.   
  75. def plotTree(myTree, parentPt, nodeTxt):  
  76.     numLeafs = getNumLeafs(myTree)  # 获取决策树叶结点数目,决定了树的宽度  
  77.     depth = getTreeDepth(myTree)  # 获取决策树层数  
  78.     firstStr = next(iter(myTree))  # 下个字典  
  79.     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)  # 中心位置  
  80.     plotMidText(cntrPt, parentPt, nodeTxt)  # 标注有向边属性值  
  81.     plotNode(firstStr, cntrPt, parentPt, decisionNode)  # 绘制结点  
  82.     secondDict = myTree[firstStr]  # 下一个字典,也就是继续绘制子结点  
  83.     plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # y偏移  
  84.     for key in secondDict.keys():  
  85.         if type(secondDict[key]).__name__ == 'dict':  # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点  
  86.             plotTree(secondDict[key], cntrPt, str(key))  # 不是叶结点,递归调用继续绘制  
  87.         else:  # 如果是叶结点,绘制叶结点,并标注有向边属性值  
  88.             plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW  
  89.             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  
  90.             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  
  91.     plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  
  92.   
  93.   
  94. if __name__ == '__main__':  
  95.     pass  
  96.     # # retrieveTree(1)  
  97.     # myTree = retrieveTree(0)  
  98.     # # print(getNumLeafs(myTree))  
  99.     # # print(getTreeDepth(myTree))  
  100.     # createPlot(myTree)  


总结决策树的优点

  • 决策树易于理解和解释,可以可视化.

  • 几乎不需要数据预处理。其他方法经常需要数据标准化,创建虚拟变量和删除缺失值。决策树还不支持缺失值。

  • 使用树的花费(例如预测数据)是训练数据点(data points)数量的对数。

  • 可以同时处理数值变量和分类变量。其他方法大都适用于分析一种变量的集合。

  • 可以处理多值输出变量问题。

  • 使用白盒模型。如果一个情况被观察到,使用逻辑判断容易表示这种规则。相反,如果是黑盒模型(例如人工神经网络),结果会非常难解释。

  • 即使对真实模型来说,假设无效的情况下,也可以较好的适用。

决策树的缺点

  • 决策树可能会产生过多的数据集划分,从而产生过度匹配数据集的问题,可以通过裁剪决策树,合并相邻的无法产生大量信息增益的叶节点,从而消除过度匹配问题。

  • 决策树学习可能创建一个过于复杂的树,并不能很好的预测数据。也就是过拟合。修剪机制(现在不支持),设置一个叶子节点需要的最小样本数量,或者数的最大深度,可以避免过拟合。

  • 决策树可能是不稳定的,因为即使非常小的变异,可能会产生一颗完全不同的树。这个问题通过decision trees with an ensemble来缓解。

  • 学习一颗最优的决策树是一个NP-完全问题under several aspects of optimality and even for simple concepts。因此,传统决策树算法基于启发式算法,例如贪婪算法,即每个节点创建最优决策。这些算法不能产生一个全家最优的决策树。对样本和特征随机抽样可以降低整体效果偏差。

  • 概念难以学习,因为决策树没有很好的解释他们,例如,XOR, parity or multiplexer problems.

  • 如果某些分类占优势,决策树将会创建一棵有偏差的树。因此,建议在训练之前,先抽样使样本均衡。


还有其他决策树的构造算法,最流向的是C4.5和CART。本文是ID3构造算法。


如果觉得我的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!

¥ 打赏支持
448人赞 举报
分享到
用户评价(0)

暂无评价,你也可以发布评价哦:)

扫码APP

扫描使用APP

扫码使用

扫描使用小程序