Decision Tree Implementation


Constructing the Decision Tree

from math import log
def createDataSet():
    dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet, labels
myData, labels = createDataSet()

Information Gain

The measure is based on information entropy (Shannon entropy).
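
For a dataset $D$ with $K$ classes, where class $k$ makes up a proportion $p_k$ of the samples, the Shannon entropy is

$$H(D) = -\sum_{k=1}^{K} p_k \log_2 p_k$$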

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    # count the occurrences of each class label
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    # accumulate -p * log2(p) over all classes
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
calcShannonEnt(myData)
0.9709505944546686
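
With two 'yes' and three 'no' labels, this works out to

$$H = -\tfrac{2}{5}\log_2\tfrac{2}{5} - \tfrac{3}{5}\log_2\tfrac{3}{5} \approx 0.971$$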

Add a third class named 'maybe' and see how the entropy changes:

myData[0][-1] = 'maybe'
calcShannonEnt(myData)
1.3709505944546687
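
The class counts are now maybe: 1, yes: 1, no: 3, so

$$H = -2\cdot\tfrac{1}{5}\log_2\tfrac{1}{5} - \tfrac{3}{5}\log_2\tfrac{3}{5} \approx 1.371$$

More classes and a more even spread mean more disorder, hence the higher entropy.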

Another way to measure the disorder of a set is the Gini impurity, which is the probability that two samples drawn at random from the set carry different class labels.
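
The rest of this post sticks with entropy, but for comparison, here is a minimal sketch of Gini impurity over the same dataSet format (calcGini is my own helper, not part of the book's code):

def calcGini(dataSet):
    # Gini impurity: 1 - sum of squared class proportions, i.e. the
    # probability that two randomly drawn samples have different labels
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        labelCounts[featVec[-1]] = labelCounts.get(featVec[-1], 0) + 1
    return 1.0 - sum((count / numEntries) ** 2 for count in labelCounts.values())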

Splitting the Dataset

Compute the entropy for each feature's split and choose the feature with the best information gain. The function below splits the dataset on a given feature; its inputs are the dataset to split, the feature index, and the feature value to select.
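
Formally, the information gain from splitting $D$ on feature $A$, whose values $v$ partition $D$ into subsets $D_v$, is the entropy minus the conditional entropy:

$$\mathrm{Gain}(D, A) = H(D) - \sum_{v} \frac{|D_v|}{|D|} H(D_v)$$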

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featureVec in dataSet:
        # keep only the samples whose feature `axis` equals `value`,
        # and drop that feature from the returned vectors
        if featureVec[axis] == value:
            reducedFeatVec = [featureVec[i] for i in range(len(featureVec)) if i != axis]
            retDataSet.append(reducedFeatVec)
    return retDataSet
myData, labels = createDataSet()
splitDataSet(myData, 0, 1)
[[1, 'yes'], [1, 'yes'], [0, 'no']]
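
Splitting on the other value of the same feature gives the complementary group:

splitDataSet(myData, 0, 0)
[[1, 'no'], [1, 'no']]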

Choosing the feature with the largest information gain:

def bestFeatSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    entropy = calcShannonEnt(dataSet)
    bestInfoGain = 0
    bestFeature = 0
    for i in range(numFeatures):
        featList = [data[i] for data in dataSet]
        uniqueVals = set(featList)     # the values this feature takes
        newEntropy = 0
        # compute the conditional entropy for this feature
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / len(dataSet)
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = entropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
bestFeature = bestFeatSplit(myData)
bestFeature
0

The split above matches intuition: in the group where the first feature is 1, two samples are 'yes' and one is 'no'; in the group where it is 0, both samples are 'no'. That is a pretty good split.
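
Working through the numbers confirms this choice:

$$\mathrm{Gain}(D, \text{no surfacing}) = 0.971 - \tfrac{3}{5}\,H(\tfrac{2}{3},\tfrac{1}{3}) - \tfrac{2}{5}\cdot 0 \approx 0.971 - 0.551 = 0.420$$

$$\mathrm{Gain}(D, \text{flippers}) = 0.971 - \tfrac{4}{5}\cdot 1 - \tfrac{1}{5}\cdot 0 \approx 0.171$$

so feature 0 is indeed the better split.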

Building the Tree Recursively

The recursion stops when the program has exhausted all the features available for splitting, or when every instance under a branch has the same class.

We could also cap the maximum number of groups the algorithm may split into. If the dataset has run out of features but the class labels are still not unique, we use majority voting and assign the most frequent class to that node.

def maxClassCount(classList):
    classCount = {}
    for lclass in classList:
        if lclass not in classCount.keys():
            classCount[lclass] = 0
        classCount[lclass] += 1
    maxKeyCount = 0
    maxCountKey = None
    # find the class with the highest count
    for key in classCount.keys():
        if classCount[key] > maxKeyCount:
            maxKeyCount = classCount[key]
            maxCountKey = key
    return maxCountKey
maxClassCount(['a','a','b'])
'a'
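
The same majority vote can be written more compactly with the standard library's collections.Counter (an equivalent alternative, not the book's version):

from collections import Counter

def maxClassCountAlt(classList):
    # most_common(1) returns [(label, count)] for the most frequent label
    return Counter(classList).most_common(1)[0][0]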

Creating the tree recursively:

import copy
def createTree(dataSet, labels):
    classList = [data[-1] for data in dataSet]
    # stop splitting if all classes are identical
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # all features consumed: return the most frequent class
    if len(dataSet[0]) == 1:
        return maxClassCount(classList)
    bestFeat = bestFeatSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featVals = [data[bestFeat] for data in dataSet]
    uniqueVals = set(featVals)
    for value in uniqueVals:
        # copy the label list so the recursion does not mutate the original
        subLabels = copy.deepcopy(labels)
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
myTree = createTree(myData, labels)
myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

Visualizing the Tree

Create the tree diagram with the matplotlib library.

Drawing tree nodes with text annotations:

import matplotlib.pyplot as plt
%matplotlib notebook
# define the box and arrow styles for decision nodes and leaf nodes
decisionNode = dict(boxstyle="sawtooth", fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', 
                   va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)

def createPlot():
    global ax1
    fig = plt.figure()
    # clear the figure
    fig.clf()
    ax1 = plt.subplot(111, frameon=False)
    plotNode('decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
createPlot()

Constructing the Annotated Tree

How do we place all the nodes? We need to know the number of leaf nodes so we can fix the width of the x-axis, and the number of levels in the tree so we can fix the height of the y-axis. Define two new functions, getNumLeafs() and getTreeDepth(), to get the number of leaves and the depth of the tree.

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]  # the root feature label
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # a dict means a subtree: recurse into it
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
numLeafs = getNumLeafs(myTree)
numLeafs
3
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        # unlike counting leaves, we want the maximum, not the sum
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
maxDepth = getTreeDepth(myTree)
maxDepth
2

Now we can draw a complete tree:

def plotMidText(centerPt, parentPt, txt):
    # write the edge label midway between parent and child
    xMid = (parentPt[0] + centerPt[0]) / 2
    yMid = (parentPt[1] + centerPt[1]) / 2
    ax1.text(xMid, yMid, txt)

# redefine createPlot()
def createPlot(myTree):
    global ax1
    global totalW, totalD, xoff, yoff
    fig = plt.figure()
    fig.clf()
    ax1 = plt.subplot(111, frameon=False)
    totalW = getNumLeafs(myTree)
    totalD = getTreeDepth(myTree)
    xoff = -0.5 / totalW
    yoff = 1
    plotTree(myTree, (0.5, 1.0), '')
    plt.show()
def plotTree(tree, parentPt, nodeTxt):
    global ax1
    global totalW, totalD, xoff, yoff
    numLeafs = getNumLeafs(tree)
    firstStr = list(tree.keys())[0]
    centerPt = (xoff + (1 + numLeafs) / (2 * totalW), yoff)
    plotMidText(centerPt, parentPt, nodeTxt)   # draw the edge label
    plotNode(firstStr, centerPt, parentPt, decisionNode)    # draw the node and its arrow
    secondDict = tree[firstStr]
    yoff = yoff - 1 / totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # recurse into the subtree
            plotTree(secondDict[key], centerPt, str(key))
            # restore yoff after the subtree is drawn
            yoff = yoff + 1 / totalD
        else:
            xoff += 1 / totalW
            plotNode(secondDict[key], (xoff, yoff), centerPt, leafNode)
            plotMidText((xoff, yoff), centerPt, str(key))
createPlot({'no surfacing':{0:'no',1:{'flippers':{0:'no',1:{'foot':{0:'no',1:'yes'}}}},3:{'eyes':{0:'no',1:'yes'}}}})

Testing and Storing the Tree

Classifying with the decision tree:

def classify(myTree, featLabels, testVec):
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    # translate the feature label into an index into testVec
    featIndex = featLabels.index(firstStr)
    classLabel = None   # returned unchanged if the feature value was never seen
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                # internal node: keep descending
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                # leaf node: this is the predicted class
                classLabel = secondDict[key]
    return classLabel
classify(myTree, ['no surfacing','flippers'], [1,0])
'no'
classify(myTree, ['no surfacing','flippers'], [1,1])
'yes'

Storing the decision tree: serialize the tree with the pickle module; the serialized object can be saved to disk and read back whenever it is needed.

import pickle

def storeTree(tree, filename):
    # serialize the tree to disk
    with open(filename, 'wb') as f:
        pickle.dump(tree, f)

def getTree(filename):
    # load a previously stored tree
    with open(filename, 'rb') as f:
        return pickle.load(f)
storeTree(myTree, 'myDecisionTree.txt')
getTree('myDecisionTree.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

Predicting Contact Lens Types with a Decision Tree

with open('lenses.txt', 'r') as f:
    lenses = [data.strip().split('\t') for data in f.readlines()]
lenses
[['young', 'myope', 'no', 'reduced', 'no lenses'],
 ['young', 'myope', 'no', 'normal', 'soft'],
 ['young', 'myope', 'yes', 'reduced', 'no lenses'],
 ['young', 'myope', 'yes', 'normal', 'hard'],
 ['young', 'hyper', 'no', 'reduced', 'no lenses'],
 ['young', 'hyper', 'no', 'normal', 'soft'],
 ['young', 'hyper', 'yes', 'reduced', 'no lenses'],
 ['young', 'hyper', 'yes', 'normal', 'hard'],
 ['pre', 'myope', 'no', 'reduced', 'no lenses'],
 ['pre', 'myope', 'no', 'normal', 'soft'],
 ['pre', 'myope', 'yes', 'reduced', 'no lenses'],
 ['pre', 'myope', 'yes', 'normal', 'hard'],
 ['pre', 'hyper', 'no', 'reduced', 'no lenses'],
 ['pre', 'hyper', 'no', 'normal', 'soft'],
 ['pre', 'hyper', 'yes', 'reduced', 'no lenses'],
 ['pre', 'hyper', 'yes', 'normal', 'no lenses'],
 ['presbyopic', 'myope', 'no', 'reduced', 'no lenses'],
 ['presbyopic', 'myope', 'no', 'normal', 'no lenses'],
 ['presbyopic', 'myope', 'yes', 'reduced', 'no lenses'],
 ['presbyopic', 'myope', 'yes', 'normal', 'hard'],
 ['presbyopic', 'hyper', 'no', 'reduced', 'no lenses'],
 ['presbyopic', 'hyper', 'no', 'normal', 'soft'],
 ['presbyopic', 'hyper', 'yes', 'reduced', 'no lenses'],
 ['presbyopic', 'hyper', 'yes', 'normal', 'no lenses']]
Each raw line of lenses.txt is tab-separated, e.g. 'young\tmyope\tno\treduced\tno lenses'.
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
lensesTree
{'tearRate': {'reduced': 'no lenses',
  'normal': {'astigmatic': {'yes': {'prescript': {'myope': 'hard',
      'hyper': {'age': {'presbyopic': 'no lenses',
        'pre': 'no lenses',
        'young': 'hard'}}}},
    'no': {'age': {'presbyopic': {'prescript': {'myope': 'no lenses',
        'hyper': 'soft'}},
      'pre': 'soft',
      'young': 'soft'}}}}}}
createPlot(lensesTree)

Frankly, this tree is too dense and may overfit; it needs pruning, which I'll leave for a later post.

