百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

决策树算法应用及结果解读

toyiye 2024-06-21 12:24 7 浏览 0 评论

引言

本文是我写的人工智能系列的第 8 篇文章,文末有前面 7 篇文章的链接,推荐你阅读、分享和交流。

1. 决策树算法简介

决策树是一种应用非常广泛的算法,比如语音识别、人脸识别、医疗诊断、模式识别等。

决策树算法既可以解决分类问题(对应的目标值是类别型的数据),也能解决回归问题(输出结果也可以是连续的数值)。

相比其他算法,决策树有一个非常明显的优势,就是可以很直观地进行可视化,分类规则好理解,让非专业的人也容易看明白。

比如某个周末,你根据天气等情况决定是否出门,如果降雨就不出门,否则看是否有雾霾……这个决策的过程,可以画成这样一颗树形图:

下面我们以 sklearn 中的葡萄酒数据集为例,给定一些数据指标,比如酒精度等,利用决策树算法,可以判断出葡萄酒的类别。

2. 加载数据

为了方便利用图形进行可视化演示,我们只选取其中 2 个特征:第 1 个特征(酒精度)和第 7 个特征(黄酮量),并绘制出 3 类葡萄酒相应的散点图。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

# 加载葡萄酒的数据集
wine = datasets.load_wine()

# 为了方便可视化,只选取 2 个特征
X = wine.data[:, [0, 6]]
y = wine.target

# 绘制散点图
plt.scatter(X[y==0, 0], X[y==0, 1])
plt.scatter(X[y==1, 0], X[y==1, 1])
plt.scatter(X[y==2, 0], X[y==2, 1])
plt.show()

在上面的散点图中,颜色代表葡萄酒的类别,横轴代表酒精度,纵轴代表黄酮量。

3. 调用算法

和调用其他算法的方法一样,我们先把数据集拆分为训练集和测试集,然后指定相关参数,这里我们指定决策树的最大深度等于 2,并对算法进行评分。

from sklearn.model_selection import train_test_split
from sklearn import tree

# 拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# 调用决策树分类算法
dtc = tree.DecisionTreeClassifier(max_depth=2)
dtc.fit(X_train, y_train)

# 算法评分
print('训练得分:', dtc.score(X_train, y_train))
print('测试得分:', dtc.score(X_test, y_test))
训练得分:0.9172932330827067
测试得分:0.8666666666666667

从上面的结果可以看出,决策树算法的训练得分和测试得分都还不错。

假如设置 max_depth = 1,那么算法评分很低,就会出现欠拟合的问题。

假如设置 max_depth = 1 0,那么虽然算法的评分变高了,但是决策树变得过于复杂,就会出现过拟合的问题。

关于模型复杂度的问题讨论,可以参考:模型越复杂越好吗?

4. 决策边界

为了更加直观地看到算法的分类效果,我们定义一个绘制决策边界的函数,画出分类的边界线。

from matplotlib.colors import ListedColormap

# 定义绘制决策边界的函数
def plot_decision_boundary(model, axis):
    
    x0, x1 = np.meshgrid(
        np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1,1),
        np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1,1)
    )
    X_new = np.c_[x0.ravel(), x1.ravel()]
    
    y_predict = model.predict(X_new)
    zz = y_predict.reshape(x0.shape)
    
    custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
    
    plt.contourf(x0, x1, zz, cmap=custom_cmap)
    
# 绘制决策边界
plot_decision_boundary(dtc, axis=[11, 15, 0, 6])
plt.scatter(X[y==0, 0], X[y==0, 1])
plt.scatter(X[y==1, 0], X[y==1, 1])
plt.scatter(X[y==2, 0], X[y==2, 1])
plt.show()

从图中也可以直观地看出,大部分数据点的分类是基本准确的,这也说明决策树算法的效果还不错。

5. 树形图

为了能够更加直观地理解决策树算法,我们可以用树形图来展示算法的结果。

# 导入相关库,需要先安装 graphviz 和 pydotplus,并在电脑中 Graphviz 软件
import pydotplus
from sklearn.tree import export_graphviz
from IPython.display import Image
from io import StringIO

# 将对象写入内存中
dot_data = StringIO()

# 生成决策树结构
tree.export_graphviz(dtc, class_names=wine.target_names,
                     feature_names=[wine.feature_names[0], wine.feature_names[6]],
                     rounded=True, filled=True, out_file = dot_data)

# 生成树形图并展示出来
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())

6. 结果解读

从上面的树形图来看,在葡萄酒数据的训练集中,有 133 个数据,划分为 3 个类别,数量分别是 43、50、40 个,对应的标签分别是 class_0、class_1、class_2,其中 class_1 的数量最多,所以最上面的根节点认为,类别为 class_1 的可能性最大,Gini 系数为 0.664,它是利用下面的公式计算出来的:

1 - (43/133)**2 - (50/133)**2 - (40/133)**2

在决策树算法中,Gini 系数代表样本的不确定性。 当每个类别的数量越趋近于平均值,Gini 系数就越大,也就越不确定。

比如扔硬币的游戏,在一般情况下,正反两面的概率都是 50%,此时 Gini 系数等于 0.5,你猜中的概率也是 50%;假如你对硬币做了手脚,把两面都变成正面图案,此时Gini 系数等于 0, 也就是说,不确定性为 0,你能明确地知道肯定是正面。

在上面葡萄酒的例子中,当黄酮量 <= 1.575 时,有 49 个样本,3 个类别的数量分别是 0、9、40 个,其中 class_2 的数量最多,Gini 系数为 0.3,比上面的节点要低,说明分类结果变得更加确定。当酒精量 > 12.41 时,有 39 个样本,3 个类别的数量分别是 0、2、37个,Gini 系数为 0.097,此时分类结果变得更加确定为 class_2。

树形图中其他节点的结果含义类似,在此不再赘述。

小结

本文介绍了决策树算法的应用,以葡萄酒数据集为例,演示了决策树算法的实现过程,绘制了直观易懂的决策边界和树形图,并对决策结果做了详细解读。

虽然决策树算法有很多优点,比如高效、易懂,但是也有它的不足之处,比如当参数设置不当时,很容易出现过拟合的问题。

为了避免决策树算法出现过拟合的问题,可以使用「 集成学习 」的方法,融合多种不同的算法,也就是俗话讲的「三个臭皮匠,赛过诸葛亮」。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码