脚本专栏 发布日期:2025/11/13 浏览次数:1
决策树之ID3算法及其Python实现,具体内容如下
主要内容
决策树背景知识
决策树一般构建过程
ID3算法分裂属性的选择
ID3算法流程及其优缺点分析
ID3算法Python代码实现
1. 决策树背景知识
"color: #800000">注:分裂属性的选取是决策树生产过程中的关键,它决定了生成的决策树的性能、结构。分裂属性选择的评判标准是决策树算法之间的根本区别。
3. ID3算法分裂属性的选择——信息增益
"htmlcode">
# -*- coding: utf-8 -*-
__author__ = 'zhihua_oba'
import operator
from numpy import *
from math import log
#文件读取
def file2matrix(filename, attribute_num): #传入参数:文件名,属性个数
fr = open(filename)
arrayOLines = fr.readlines()
numberOfLines = len(arrayOLines) #统计数据集行数(样本个数)
dataMat = zeros((numberOfLines, attribute_num))
classLabelVector = [] #分类标签
index = 0
for line in arrayOLines:
line = line.strip() #strip() 删除字符串中的'\n'
listFromLine = line.split() #将一个字符串分裂成多个字符串组成的列表,不带参数时以空格进行分割,当代参数时,以该参数进行分割
dataMat[index, : ] = listFromLine[0:attribute_num] #读取数据对象属性值
classLabelVector.append(listFromLine[-1]) #读取分类信息
index += 1
dataSet = [] #数组转化成列表
index = 0
for index in range(0, numberOfLines):
temp = list(dataMat[index, :])
temp.append(classLabelVector[index])
dataSet.append(temp)
return dataSet
#划分数据集
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featvec in dataSet: #每行
if featvec[axis] == value: #每行中第axis个元素和value相等 #删除对应的元素,并将此行,加入到rerDataSet
reducedFeatVec = featvec[:axis]
reducedFeatVec.extend(featvec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
#计算香农熵 #计算数据集的香农熵 == 计算数据集类标签的香农熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet) #数据集样本点个数
labelCounts = {} #类标签
for featVec in dataSet: #统计数据集类标签的个数,字典形式
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
#根据香农熵,选择最优的划分方式 #根据某一属性划分后,类标签香农熵越低,效果越好
def chooseBestFeatureToSplit(dataSet):
baseEntropy = calcShannonEnt(dataSet) #计算数据集的香农熵
numFeatures = len(dataSet[0])-1
bestInfoGain = 0.0 #最大信息增益
bestFeature = 0 #最优特征
for i in range(0, numFeatures):
featList = [example[i] for example in dataSet] #所有子列表(每行)的第i个元素,组成一个新的列表
uniqueVals = set(featList)
newEntorpy = 0.0
for value in uniqueVals: #数据集根据第i个属性进行划分,计算划分后数据集的香农熵
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntorpy += prob*calcShannonEnt(subDataSet)
infoGain = baseEntropy-newEntorpy #划分后的数据集,香农熵越小越好,即信息增益越大越好
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
#如果数据集已经处理了所有属性,但叶子结点中类标签依然不是唯一的,此时需要决定如何定义该叶子结点。这种情况下,采用多数表决方法,对该叶子结点进行分类
def majorityCnt(classList): #传入参数:叶子结点中的类标签
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
#创建树
def createTree(dataSet, labels): #传入参数:数据集,属性标签(属性标签作用:在输出结果时,决策树的构建更加清晰)
classList = [example[-1] for example in dataSet] #数据集样本的类标签
if classList.count(classList[0]) == len(classList): #如果数据集样本属于同一类,说明该叶子结点划分完毕
return classList[0]
if len(dataSet[0]) == 1: #如果数据集样本只有一列(该列是类标签),说明所有属性都划分完毕,则根据多数表决方法,对该叶子结点进行分类
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) #根据香农熵,选择最优的划分方式
bestFeatLabel = labels[bestFeat] #记录该属性标签
myTree = {bestFeatLabel:{}} #树
del(labels[bestFeat]) #在属性标签中删除该属性
#根据最优属性构建树
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
subDataSet = splitDataSet(dataSet, bestFeat, value)
myTree[bestFeatLabel][value] = createTree(subDataSet, subLabels)
return myTree
#测试算法:使用决策树,对待分类样本进行分类
def classify(inputTree, featLabels, testVec): #传入参数:决策树,属性标签,待分类样本
firstStr = inputTree.keys()[0] #树根代表的属性
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr) #树根代表的属性,所在属性标签中的位置,即第几个属性
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
def main():
dataSet = file2matrix('test_sample.txt', 4)
labels = ['attr01', 'attr02', 'attr03', 'attr04']
labelsForCreateTree = labels[:]
Tree = createTree(dataSet, labelsForCreateTree )
testvec = [2, 3, 2, 3]
print classify(Tree, labels, testvec)
if __name__ == '__main__':
main()
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。