
Commit 35e56e9

committed Jul 5, 2018
Decision tree algorithm
1 parent 3202a0c commit 35e56e9

13 files changed: +352 -0 lines changed
 

‎ch02-Decision-Tree/Cal_Entropy.py

+30
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 4 13:06:48 2018

@author: Administrator
"""

from math import log

def cal_entropy(data):
    '''Compute the Shannon entropy of a set of labeled samples.'''
    entries_num = len(data)
    label_count = {}  # dict mapping each class label to its occurrence count

    for vec in data:
        cur_label = vec[-1]
        # The class label is the last element of each sample vector
        label_count[cur_label] = label_count.get(cur_label, 0) + 1

    entropy = 0.0
    # For each class, estimate the probability of drawing that class
    # from the sample, then accumulate -p * log2(p)
    for key in label_count:
        prob = float(label_count[key]) / entries_num
        # The original called math.log, but only `from math import log`
        # was imported; use the imported log directly
        entropy -= prob * log(prob, 2)
    return entropy
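
Quick sanity check (a hedged sketch, not part of this commit): on the five-sample dataset used in __main__.py below, two of five labels are 'yes' and three are 'no', so cal_entropy should return about 0.9710.

from Cal_Entropy import cal_entropy

data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
# H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.9710
print(cal_entropy(data))  # ~0.9709505944546686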

‎ch02-Decision-Tree/Classify_tree.py

+28
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 4 21:44:42 2018

@author: Administrator
"""

def classify(inp_tree, labels, test_vec):
    '''Walk the decision tree until a leaf (class label) is reached.'''
    first_node = list(inp_tree.keys())[0]
    second_dict = inp_tree[first_node]
    index = labels.index(first_node)

    for key in second_dict.keys():
        if test_vec[index] == key:
            if type(second_dict[key]).__name__ == 'dict':
                # Internal node: recurse into the matching subtree
                class_label = classify(second_dict[key], labels, test_vec)
            else:
                # Leaf node: its value is the class label
                class_label = second_dict[key]
    return class_label

def store_tree(inp_tree, filename):
    '''Serialize the tree to disk with pickle.'''
    import pickle
    # pickle writes bytes, so the file must be opened in binary mode
    # (the original used 'w', which fails in Python 3)
    with open(filename, 'wb') as fp:
        pickle.dump(inp_tree, fp)

def grab_tree(filename):
    '''Load a pickled tree from disk.'''
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
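
A hedged usage sketch (not part of this commit): classifying two feature vectors against the hard-coded test tree returned by retrieveTree in Plot_tree.py.

from Classify_tree import classify
from Plot_tree import retrieveTree

labels = ['no surfacing', 'flippers']
tree = retrieveTree(0)
print(classify(tree, labels, [1, 0]))  # 'no':  'no surfacing'=1 -> 'flippers'=0
print(classify(tree, labels, [1, 1]))  # 'yes': 'no surfacing'=1 -> 'flippers'=1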

‎ch02-Decision-Tree/Decision_Tree.py

+58
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 4 16:42:04 2018

@author: Administrator
"""
import operator
from Split_by_entropy import *

def Majority_vote(classList):
    '''
    Majority vote: if class K has the most samples in this branch,
    the whole branch is labeled as class K.
    '''
    classcount = {}
    for vote in classList:
        classcount[vote] = classcount.get(vote, 0) + 1
    sorted_count = sorted(classcount.items(), key=operator.itemgetter(1),
                          reverse=True)
    # Count how often each class appears (default 0), sort descending,
    # and return the class label with the largest count
    return sorted_count[0][0]

def Create_Tree(dataset, labels):

    classList = [x[-1] for x in dataset]
    # If every sample carries the same label, this branch is pure
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # If all features are used up, fall back to majority vote
    if len(dataset[0]) == 1:
        return Majority_vote(classList)

    best_feature = Split_by_entropy(dataset)
    best_labels = labels[best_feature]

    myTree = {best_labels: {}}
    # The book is wrong here: it does del(labels[bestFeat]), which
    # mutates the caller's original list, and the program later fails
    # with "'no surfacing' is not in list". Corrected below.

    # Copy the current feature-label list so the original is untouched
    subLabels = labels[:]
    # Remove the feature used for this split from the copy
    del(subLabels[best_feature])

    # Collect this feature's column with a list comprehension
    f_val = [x[best_feature] for x in dataset]
    uni_val = set(f_val)
    for value in uni_val:
        # Recursively build and attach each subtree
        myTree[best_labels][value] = Create_Tree(
            Split_Data(dataset, best_feature, value), subLabels)

    return myTree
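
For reference, a hedged sketch of the tree Create_Tree builds from the five-sample dataset in __main__.py; the expected output below follows from tracing the code (feature 0 has the larger information gain, so it becomes the root):

from Decision_Tree import Create_Tree

myData = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
print(Create_Tree(myData, labels))
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}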

‎ch02-Decision-Tree/Plot_tree.py

+107
@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 4 21:10:10 2018

@author: Administrator
"""

import matplotlib.pyplot as plt

# Define the text-box and arrow styles
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Draw an annotation with an arrow
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def Num_of_leaf(myTree):
    '''Count the leaf nodes of this tree.'''
    num_leaf = 0
    first_node = list(myTree.keys())[0]
    second_dict = myTree[first_node]
    # In Python 3, keys() must be wrapped in list() before indexing;
    # the book's direct [0] indexing only works in Python 2.
    # For each value: recurse if it is a dict, otherwise count a leaf.
    for key in second_dict.keys():
        if type(second_dict[key]).__name__ == 'dict':
            num_leaf += Num_of_leaf(second_dict[key])
        else:
            num_leaf += 1
    return num_leaf

def Depth_of_tree(myTree):
    '''Compute the total depth of this tree.'''
    depth = 0
    first_node = list(myTree.keys())[0]
    second_dict = myTree[first_node]

    for key in second_dict.keys():
        if type(second_dict[key]).__name__ == 'dict':
            pri_depth = 1 + Depth_of_tree(second_dict[key])
        else:
            pri_depth = 1
        # For each value: recurse if it is a dict, otherwise depth is 1
        if pri_depth > depth:
            depth = pri_depth
    return depth

def retrieveTree(i):
    '''
    Hard-coded trees for testing.
    '''
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no',
                        1: 'yes'}}}},
                   {'no surfacing': {0: 'no',
                        1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                            1: 'no'}}}}]
    return listOfTrees[i]

def plotmidtext(cntrpt, parentpt, txtstring):
    '''Place a text label at the midpoint of an edge.
    cntrpt: start (child) position; parentpt: end (parent) position;
    txtstring: the text label.
    '''
    xmid = (parentpt[0] - cntrpt[0]) / 2.0 + cntrpt[0]
    ymid = (parentpt[1] - cntrpt[1]) / 2.0 + cntrpt[1]
    # Midpoint between the child and parent coordinates
    createPlot.ax1.text(xmid, ymid, txtstring)


def plottree(mytree, parentpt, nodetxt):
    numleafs = Num_of_leaf(mytree)
    depth = Depth_of_tree(mytree)
    firststr = list(mytree.keys())[0]
    # Compute this subtree root's coordinate from its leaf count
    cntrpt = (plottree.xoff + (1.0 + float(numleafs)) / 2.0 / plottree.totalw,
              plottree.yoff)
    plotmidtext(cntrpt, parentpt, nodetxt)  # draw the edge label
    plotNode(firststr, cntrpt, parentpt, decisionNode)  # draw the node
    seconddict = mytree[firststr]
    # Step down one level: decrease y by 1.0/plottree.totald so that
    # depth is reflected on the y axis
    plottree.yoff = plottree.yoff - 1.0 / plottree.totald
    for key in seconddict.keys():
        if type(seconddict[key]).__name__ == 'dict':
            plottree(seconddict[key], cntrpt, str(key))
        else:
            plottree.xoff = plottree.xoff + 1.0 / plottree.totalw
            plotNode(seconddict[key], (plottree.xoff, plottree.yoff),
                     cntrpt, leafNode)
            plotmidtext((plottree.xoff, plottree.yoff), cntrpt, str(key))
    plottree.yoff = plottree.yoff + 1.0 / plottree.totald


def createPlot(intree):
    # Like MATLAB's figure: create a canvas with a white background
    fig = plt.figure(1, facecolor='white')
    fig.clf()  # clear the canvas
    axprops = dict(xticks=[], yticks=[])
    # createPlot.ax1 is a global drawing handle; subplot(111) means a
    # 1x1 grid, first plot; frameon=False hides the axes frame
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)

    plottree.totalw = float(Num_of_leaf(intree))
    plottree.totald = float(Depth_of_tree(intree))
    plottree.xoff = -0.6 / plottree.totalw
    plottree.yoff = 1.2
    plottree(intree, (0.5, 1.0), '')
    plt.show()
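
A minimal usage sketch (not part of this commit), assuming a working matplotlib backend:

from Plot_tree import createPlot, retrieveTree

createPlot(retrieveTree(0))  # renders the 3-leaf test tree in a matplotlib window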
‎ch02-Decision-Tree/Split_by_entropy.py

+58
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 4 13:35:15 2018

@author: Administrator
"""

from Cal_Entropy import *

def Split_Data(dataset, axis, value):
    '''
    Split the dataset on the given axis and value.
    axis is the feature's position in each row; value is the feature
    value to split on.
    '''
    new_subset = []
    # Keep only the rows whose feature equals value, and drop that
    # feature's column from each kept row (it becomes the branch/leaf)
    for vec in dataset:
        if vec[axis] == value:
            feature_split = vec[:axis]
            feature_split.extend(vec[axis + 1:])
            new_subset.append(feature_split)
    # extend merges vec's elements one by one into feature_split;
    # append then adds feature_split as a whole row to the subset

    return new_subset

def Split_by_entropy(dataset):
    '''
    Choose the split with the entropy criterion.
    @information gain: info_gain = old - new
    @best feature: best_feature
    @value set: uniVal
    '''
    feature_num = len(dataset[0]) - 1
    ent_old = cal_entropy(dataset)
    best_gain = 0.0
    best_feature = -1
    # ent_old is the entropy before the split, ent_new the weighted
    # entropy after it; best_gain is updated as each feature is tried,
    # so the best feature wins in the end
    for i in range(feature_num):
        feature_list = [x[i] for x in dataset]
        uniVal = set(feature_list)
        ent_new = 0.0
        # set() removes duplicates, leaving the feature's distinct values
        for value in uniVal:
            sub_set = Split_Data(dataset, i, value)
            prob = len(sub_set) / float(len(dataset))
            # Weighted entropy of the subsets after the split.
            # (The original wrote prob * (0 - cal_entropy(sub_set)),
            # negating an already-positive entropy; fixed here.)
            ent_new += prob * cal_entropy(sub_set)

        # The feature with the largest ent_old - ent_new wins
        Info_gain = ent_old - ent_new
        if Info_gain > best_gain:
            best_gain = Info_gain
            best_feature = i

    return best_feature
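
A hedged worked example (not part of this commit): on the five-sample dataset, splitting on feature 0 yields the larger information gain, so Split_by_entropy returns 0.

from Split_by_entropy import Split_by_entropy

myData = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
# Feature 0: ent_new = 3/5 * H(2 yes, 1 no) + 2/5 * H(2 no)
#                    = 0.6 * 0.9183 + 0.4 * 0 ≈ 0.5510 -> gain ≈ 0.4200
# Feature 1: ent_new = 4/5 * H(2 yes, 2 no) + 1/5 * H(1 no)
#                    = 0.8 * 1.0   + 0.2 * 0 = 0.8     -> gain ≈ 0.1710
print(Split_by_entropy(myData))  # 0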

‎ch02-Decision-Tree/__init__.py

+1
@@ -0,0 +1 @@
#

‎ch02-Decision-Tree/__main__.py

+46
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 4 13:23:05 2018

@author: Administrator
"""

from Cal_Entropy import *
from Split_by_entropy import *
from Decision_Tree import *
from Plot_tree import *
from Classify_tree import *

def create_data():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

if __name__ == '__main__':
    myData, labels = create_data()
    print(myData)
    print(cal_entropy(myData))

    print(Split_Data(myData, 0, 1))
    print(Split_by_entropy(myData))

    mytree = Create_Tree(myData, labels)
    print(mytree)

    myTree = retrieveTree(0)
    print(Num_of_leaf(myTree), Depth_of_tree(myTree))
    # Add an extra branch to check that plotting handles wider trees
    myTree['no surfacing'][3] = 'maybe'
    createPlot(myTree)

    with open('lenses.txt') as fp:
        lenses = [line.strip().split('\t') for line in fp.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

    lense_Tree = Create_Tree(lenses, lensesLabels)
    #createPlot(lense_Tree)
    # The original passed 'reducedno', which matches no tearRate value;
    # the valid values in lenses.txt are 'reduced' and 'normal'
    print(classify(lense_Tree, lensesLabels, ['young', 'hyper', 'yes', 'reduced']))
5 binary files not shown.

‎ch02-Decision-Tree/lenses.txt

+24
@@ -0,0 +1,24 @@
young	myope	no	reduced	no lenses
young	myope	no	normal	soft
young	myope	yes	reduced	no lenses
young	myope	yes	normal	hard
young	hyper	no	reduced	no lenses
young	hyper	no	normal	soft
young	hyper	yes	reduced	no lenses
young	hyper	yes	normal	hard
pre	myope	no	reduced	no lenses
pre	myope	no	normal	soft
pre	myope	yes	reduced	no lenses
pre	myope	yes	normal	hard
pre	hyper	no	reduced	no lenses
pre	hyper	no	normal	soft
pre	hyper	yes	reduced	no lenses
pre	hyper	yes	normal	no lenses
presbyopic	myope	no	reduced	no lenses
presbyopic	myope	no	normal	no lenses
presbyopic	myope	yes	reduced	no lenses
presbyopic	myope	yes	normal	hard
presbyopic	hyper	no	reduced	no lenses
presbyopic	hyper	no	normal	soft
presbyopic	hyper	yes	reduced	no lenses
presbyopic	hyper	yes	normal	no lenses
