加载中…
个人资料
  • 博客等级:
  • 博客积分:
  • 博客访问:
  • 关注人气:
  • 获赠金笔:0支
  • 赠出金笔:0支
  • 荣誉徽章:
正文 字体大小:

决策树ID3算法及python实现

(2017-06-15 21:33:11)
标签:

it

一、熵
      熵是信息论中表示随机变量不确定的概念,它的计算方法是:不连续分布中每一个值的概率与概率对数的乘积的和的负数,熵越大,不确定性越大。
二、信息增益与分类变量的选择
      信息增益是指根据某特征分类后的数据,其熵与特征占比乘积的和与原数据的信息熵的差值。信息增益越大选择该特征进行分类的效果越好。
三、决策树算法实施步骤
      ID3算法步骤参考李航统计学习
     

四、决策树python代码实现
      (1) R语言决策树实现
      (2)SK-LEARN实现
      (3)编制python程序

class Decision_Tree:

    #def __init__(self,dataSet,labels):
        #self.dataSet=array(dataSet)
        #self.labels=labels
        #self.tree={}

    def cacEntropy(self,dataSet):
        m,n=dataSet.shape
        labelCount={}
        for row in dataSet:
            if row[-1] in labelCount.keys():
                labelCount[row[-1]]+=1
            else:
                labelCount[row[-1]]=1
        entropy=0
        for key in labelCount.keys():
            prob=float(labelCount[key])/float(m) #注意除法浮点数据转化
            entropy+=-prob*log(prob,2)
        return entropy

    def bestFeature(self,dataSet):
        m,n=dataSet.shape
        entropy=self.cacEntropy(dataSet)
        best=0;DEntropy=-1
        for i in range(n-1):
            values=[value[i] for value in dataSet]
            uniqueValue=set(values)
            conEntropy=0
            for value in uniqueValue:
                subData=self.splitData(dataSet,i,value)
                subm,subn=subData.shape
                conEntropy+=float(subn)/float(n)*self.cacEntropy(subData)
                if entropy-conEntropy>DEntropy:
                    DEntropy=entropy-conEntropy
                    best=i
        return best
            

    def splitData(self,dataSet,axis,value):
        dataSet=list(dataSet)
        dataSplit=[]
        for rowVector in dataSet:
            if rowVector[axis]==value:
                tempData=list(rowVector[:axis])
                tempData.extend(list(rowVector[axis+1:]))
                dataSplit.append(tempData)
        return array(dataSplit)

    def sortLast(classList):
        classCount={}
        for vote in classList:
            if vote not in classCount.keys():
                classCount[vote] = 0
            classCount[vote] += 1
        sortedClassCount = sorted(classCount.iteritems(),
        key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]

    def createTree(self,dataSet,labels):
        #dataSet=self.dataSet
        m,n=dataSet.shape
        #print dataSet
        #labels=self.labels
        if len(set(dataSet[:,-1]))==1:  #所有的都是同类
            return dataSet[0,-1]
        if n==1:                        #只剩最后最后一类
            return sortLast(dataSet[:,-1])
        i=self.bestFeature(dataSet)
        bestLabel=labels[i]
        tree={labels[i]:{}}
        del labels[i]
        values=[value[i] for value in dataSet]
        uniqueValue=set(values)        #集合set为了取值不重复
        for value in uniqueValue:
            subData=self.splitData(dataSet,i,value)
            subLabels = labels[:]  #为了LABEL在循环时重复用
            tree[bestLabel][value]=self.createTree(subData,labels)
        return tree

0

阅读 收藏 喜欢 打印举报/Report
  

新浪BLOG意见反馈留言板 欢迎批评指正

新浪简介 | About Sina | 广告服务 | 联系我们 | 招聘信息 | 网站律师 | SINA English | 产品答疑

新浪公司 版权所有