JobPlus知识库 IT 工业智能4.0 文章
决策树 python代码实现

决策树任务总结:有n条训练数据,每一条数据格式为[属性1,属性2,…,属性k,结果i],即数据为n*(k+1)的矩阵。 
根据这n条数据生成一颗决策树,当来一条新数据时,能够根据k个属性,代入决策树预测出结果。 
决策树是树状,叶子节点是结果,非叶子节点是决策节点,每一个决策节点是对某个属性的判断。 
而选择哪一个属性作为当前划分属性,则是比较每一个属性划分前后信息熵变化的差异,选差异最大的作为当前划分属性。

trees.py

import math

import operator

#计算信息熵

def calcInformationEntropy(dataSet):

    numOfDataset=float(len(dataSet))

    labels=[onepice[-1] 

for onepice in dataSet]

    uniqueLabels=set(labels)

    Entropy=0.0

    for value in uniqueLabels:

        countValue=labels.count(value)

        prob=float(countValue)/numOfDataset

        Entropy+=-1*prob*math.log(prob,2)

    return Entropy

#按照第i维特征的某一个值划分数据

def splitDataSet(dataSet,featureIndex,value):

    subDataSet=[]

    for line in dataSet:

        if line[featureIndex]==value:

            newline=line[:featureIndex];

            newline.extend(line[featureIndex+1:])

            subDataSet.append(newline)

    return subDataSet

#选择最好的特征,即划分前后信息熵增益最大

def chooseBestFeatureToSplit(dataSet):

    preEntropy=calcInformationEntropy(dataSet)

    numOfFeatures=len(dataSet[0])-1

    bestFeature=-1

    maxEntropyGain=0.0

    for i in range(numOfFeatures):

        featureList=[example[i] for example in dataSet]

        uniqueFeatures=set(featureList)

        postEntropy=0.0

        for value in uniqueFeatures:

            subDataSet=splitDataSet(dataSet,i,value)

            prob=float(len(subDataSet))/float(len(dataSet))

            postEntropy+=prob*calcInformationEntropy(subDataSet)

        entropyGain=preEntropy-postEntropy

        if maxEntropyGain<entropyGain:

            maxEntropyGain=entropyGain

            bestFeature=i

    return bestFeature

#当特征已经用完,投票决定剩下的样本,少数服从多数

def majorityCnt(classList):

    classCount={}

    for vote in classList:

        if vote not in classCount.keys(): classCount[vote] = 0

        classCount[vote] += 1

    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)

    return sortedClassCount[0][0]

#根据训练样本,生成树

def createTree(dataSet,labels):

    classList=[example[-1] 

for example in dataSet]

    if classList.count(classList[0])==len(classList):

        return classList[0]

    if len(dataSet[0])==1:

        return majorityCnt(classList)

    bestFeature=chooseBestFeatureToSplit(dataSet)

    features=[example[bestFeature] for example in dataSet]

    uniqueFeatures=set(features)

    curLabel=labels[bestFeature]

    dcTree={}

    subLabels=labels[:]

    del(subLabels[bestFeature])

    for value in uniqueFeatures:

        dcTree[value]=createTree(splitDataSet(dataSet,bestFeature,value),subLabels)

    myTree={}

    myTree[curLabel]=dcTree

    return myTree

#决策子函数

def decision(Tree,inputFeature):

    firstNode=list(Tree.keys())[0]

    value=inputFeature[firstNode]

    if isinstance(Tree[firstNode][value],dict):

        return decision(Tree[firstNode][value],inputFeature)

    return Tree[firstNode][value]

#输入一条数据 ,预测他是什么类型的

def prediction(Tree,inputFeatureVec,labelsVec):

    if len(labelsVec)!=len(inputFeatureVec):

        return "error input"

    lenVec=len(labelsVec)

    inputDict={}

    for i in range(lenVec):

        inputDict[labelsVec[i]]=inputFeatureVec[i]

    return decision(Tree,inputDict)

#把树存储下来

    def storeTree(inputTree,filename):

    import pickle

    fw = open(filename,'w')

    pickle.dump(inputTree,fw)

    fw.close()

#把树还原出来

def grabTree(filename):

    import pickle

    fr = open(filename)

    return pickle.load(fr)


test.py

import trees

import pandas as pd 

df=pd.read_csv("lenses.txt",header=None,sep='\t') 

labels=['age','prescript','astigmstic','tearRate'] 

dataSet=[]

for i in range(len(df)):

    dataSet.append(list(df.loc[i][:])) 

myTree=trees.createTree(dataSet[:-1],labels) 

result=trees.prediction(myTree,dataSet[-1][:-1],labels)


如果觉得我的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!

¥ 打赏支持
89人赞 举报
分享到
用户评价(0)

暂无评价,你也可以发布评价哦:)

扫码APP

扫描使用APP

扫码使用

扫描使用小程序