之前有文章介绍过决策树(ID3)。简单回顾一下:ID3每次选取最佳特征来分割数据,这个最佳特征的判断原则是通过信息增益来实现的。按照某种特征切分数据后,该特征在以后切分数据集时就不再使用,因此存在切分过于迅速的问题。ID3算法还不能处理连续性特征。
下面简单介绍一下其他算法:
CART是Classification And Regerssion Trees的缩写,既能处理分类任务也能做回归任务。
CART树的典型代表时二叉树,根据不同的条件将分类。
CART树构建算法
与ID3决策树的构建方法类似,直接给出CART树的构建过程。首先与ID3类似采用字典树的数据结构,包含以下4中元素:
过程如下:
明显的递归算法。
通过数据过滤的方式分割数据集,返回两个子集。
def splitDatas(rows, value, column):
# 根据条件分离数据集(splitDatas by value, column)
# return 2 part(list1, list2)
list1 = []
list2 = []
if isinstance(value, int) or isinstance(value, float):
for row in rows:
if row[column] >= value:
list1.append(row)
else:
list2.append(row)
else:
for row in rows:
if row[column] == value:
list1.append(row)
else:
list2.append(row)
return list1, list2
复制代码
创建二进制决策树本质上就是递归划分输入空间的过程。
代码如下:
# gini()
def gini(rows):
# 计算gini的值(Calculate GINI)
length = len(rows)
results = calculateDiffCount(rows)
imp = 0.0
for i in results:
imp += results[i] / length * results[i] / length
return 1 - imp
复制代码
def buildDecisionTree(rows, evaluationFunction=gini):
# 递归建立决策树, 当gain=0,时停止回归
# build decision tree bu recursive function
# stop recursive function when gain = 0
# return tree
currentGain = evaluationFunction(rows)
column_lenght = len(rows[0])
rows_length = len(rows)
best_gain = 0.0
best_value = None
best_set = None
# choose the best gain
for col in range(column_lenght - 1):
col_value_set = set([x[col] for x in rows])
for value in col_value_set:
list1, list2 = splitDatas(rows, value, col)
p = len(list1) / rows_length
gain = currentGain - p * evaluationFunction(list1) - (1 - p) * evaluationFunction(list2)
if gain > best_gain:
best_gain = gain
best_value = (col, value)
best_set = (list1, list2)
dcY = {'impurity': '%.3f' % currentGain, 'sample': '%d' % rows_length}
#
# stop or not stop
if best_gain > 0:
trueBranch = buildDecisionTree(best_set[0], evaluationFunction)
falseBranch = buildDecisionTree(best_set[1], evaluationFunction)
return Tree(col=best_value[0], value = best_value[1], trueBranch = trueBranch, falseBranch=falseBranch, summary=dcY)
else:
return Tree(results=calculateDiffCount(rows), summary=dcY, data=rows)
复制代码
上面代码的功能是先找到数据集切分的最佳位置和分割数据集。之后通过递归构建出上面图片的整棵树。
在决策树的学习中,有时会造成决策树分支过多,这是就需要去掉一些分支,降低过度拟合。通过决策树的复杂度来避免过度拟合的过程称为剪枝。
后剪枝需要从训练集生成一棵完整的决策树,然后自底向上对非叶子节点进行考察。利用测试集判断是否将该节点对应的子树替换成叶节点。
代码如下:
def prune(tree, miniGain, evaluationFunction=gini):
# 剪枝 when gain < mini Gain, 合并(merge the trueBranch and falseBranch)
if tree.trueBranch.results == None:
prune(tree.trueBranch, miniGain, evaluationFunction)
if tree.falseBranch.results == None:
prune(tree.falseBranch, miniGain, evaluationFunction)
if tree.trueBranch.results != None and tree.falseBranch.results != None:
len1 = len(tree.trueBranch.data)
len2 = len(tree.falseBranch.data)
len3 = len(tree.trueBranch.data + tree.falseBranch.data)
p = float(len1) / (len1 + len2)
gain = evaluationFunction(tree.trueBranch.data + tree.falseBranch.data) - p * evaluationFunction(tree.trueBranch.data) - (1 - p) * evaluationFunction(tree.falseBranch.data)
if gain < miniGain:
tree.data = tree.trueBranch.data + tree.falseBranch.data
tree.results = calculateDiffCount(tree.data)
tree.trueBranch = None
tree.falseBranch = None
复制代码
当节点的gain小于给定的 mini Gain时则合并这两个节点.。
最后是构建树的代码:
if __name__ == '__main__':
dataSet = loadCSV()
decisionTree = buildDecisionTree(dataSet, evaluationFunction=gini)
prune(decisionTree, 0.4)
test_data = [5.9,3,4.2,1.5]
r = classify(test_data, decisionTree)
print(r)
复制代码
可以打印decisionTree可以构建出如如上的图片中的决策树。
后面找一组数据测试看能否得到正确的分类。
完整代码和数据集请查看:
github:CART
总结:
参考文章:
CART分类回归树分析与python实现
CART决策树(Decision Tree)的Python源码实现
本文链接:http://task.lmcjl.com/news/5134.html