百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

第2章 监督学习(2.3.5决策树)《Python机器学习基础教程》之五

toyiye 2024-06-21 12:30 12 浏览 0 评论

%matplotlib inline

from preamble import *

2.3.5 决策树

决策树是广泛用于分类和回归任务的模型。本质上,它从一层层的if/else问题中进行学习,并得出结论。想象一下,你想要区分下面这四种动物:熊、鹰、企鹅和海豚。你的目标是通过提出尽可能少的if/else 问题来得到正确答案。你可能首先会问:这种动物有没有羽毛,这个问题会将可能的动物减少到只有两种。如果答案是“有”,你可以问下一个问题,帮你区分鹰和企鹅。例如,你可以问这种动物会不会飞。如果这种动物没有羽毛,那么可能是海豚或熊,所以你需要问一个问题来区分这两种动物——比如问这种动物有没有鳍。

In [36]:

from scipy.misc import imread

import matplotlib.pyplot as plt

def plot_animal_tree(ax=None):

import graphviz

if ax is None:

plt.figure(figsize=(8, 8),dpi=80)

ax = plt.gca()

mygraph = graphviz.Digraph(node_attr={'shape': 'box', 'fontname':'Microsoft YaHei'},

edge_attr={'labeldistance': "10.5",'fontname':'Microsoft YaHei'},

format="png")

mygraph.node("0", "有没有羽毛?")

mygraph.node("1", "会飞吗?")

mygraph.node("2", "有没有鳍?")

mygraph.node("3", "鹰")

mygraph.node("4", "企鹅")

mygraph.node("5", "海豚")

mygraph.node("6", "熊")

mygraph.edge("0", "1", label="有")

mygraph.edge("0", "2", label="没有")

mygraph.edge("1", "3", label="会")

mygraph.edge("1", "4", label="不会")

mygraph.edge("2", "5", label="有")

mygraph.edge("2", "6", label="没有")

mygraph.render("tmp")

ax.imshow(plt.imread("tmp.png"))

ax.set_axis_off()

plot_animal_tree()

plt.suptitle("图2-22 区分几种动物的决策树",y=0.2,fontsize=18);

用机器学习的语言来说就是,为了区分四类动物(鹰、企鹅、海豚和熊),我们利用三个特征(“有没有羽毛”“会不会飞”和“有没有鳍”)来构建一个模型。我们可以利用监督学习从数据中学习模型,而无需人为构建模型。

1.构造决策树

我们在图2-23 所示的二维分类数据集上构造决策树。这个数据集由2 个半月形组成,每个类别都包含50 个数据点。我们将这个数据集称为two_moons。

学习决策树,就是学习一系列if/else 问题,使我们能够以最快的速度得到正确答案。在机器学习中,这些问题叫作测试。数据通常并不是具有二元特征(是/ 否)的形式,而是表示为连续特征,比如图2-23 所示的二维数据集,其测试形式是:“特征i 的值是否大于a ?”

In [3]:

from mglearn.tools import discrete_scatter

from mglearn.plot_helpers import cm2

from sklearn.datasets import make_moons

x, y = make_moons(n_samples=100, noise=0.25, random_state=3)

plt.figure(figsize=(8,6))

ax = plt.gca()

discrete_scatter(x[:, 0], x[:, 1], y, ax=ax)

#ax.set_xticks(())

#ax.set_yticks(())

plt.suptitle("图2-23:用于构造决策树的two_moons 数据集",y=0.05,fontsize=18);

为了构造决策树,算法搜遍所有可能的测试,找出对目标变量来说信息量最大的那一个。图2-24 展示了选出的第一个测试。将数据集在x[1]=0.0596 处划分可以得到最多信息,它在最大程度上将类别0 中的点与类别1 中的点进行区分。顶结点(也叫根结点)表示整个数据集,包含属于类别0 的50 个点和属于类别1 的50 个点。通过测试x[1] <=0.0596 的真假来对数据集进行划分,在图中表示为一条水平分隔线。如果测试结果为真,那么将这个点分配给左结点,左结点里包含属于类别0 的2 个点和属于类别1 的32 个点。否则将这个点分配给右结点,右结点里包含属于类别0 的48 个点和属于类别1 的18 个点。这两个左右的树子结点对应于图2-24 中的顶部区域和底部区域。

尽管第一次划分已经对两个类别做了很好的区分,但底部区域仍包含属于类别0 的点,顶部区域也仍包含属于类别1 的点。我们可以在两个区域中重复寻找最佳测试的过程,从而构建出更准确的模型。图2-25 展示了信息量最大的下一次划分,这次划分是基于x[0] 做出的,分为左右两个区域。

In [4]:

mglearn.plots.plot_tree_progressive1000(max_depth=1)

plt.suptitle("图2-24:深度为1 的树的决策边界(左)与相应的树(右)",y=0.1,fontsize=18);

In [5]:

mglearn.plots.plot_tree_progressive1000(max_depth=2)

plt.suptitle("图2-25:深度为2 的树的决策边界(左)与相应的树(右)",y=0.1,fontsize=18);

这一递归过程生成一棵二元决策树,其中每个结点都包含一个测试。或者你可以将每个测试看成沿着一条轴对当前数据进行划分。这是一种将算法看作分层划分的观点。由于每个测试仅关注一个特征,所以划分后的区域边界始终与坐标轴平行。对数据反复进行递归划分,直到划分后的每个区域(决策树的每个叶结点)只包含单一目标值(单一类别或单一回归值)。如果树中某个叶结点所包含数据点的目标值都相同,那么这个叶结点就是纯的(pure)。这个数据集的最终划分结果见图2-26。

In [6]:

mglearn.plots.plot_tree_progressive1000(max_depth=9)

plt.suptitle("图2-26:深度为9 的树的决策边界(左)与相应的树(右)",y=0.1,fontsize=18);

决策树可以应用于分类问题也可以回归问题

想要对新数据点进行预测,首先要查看这个点位于特征空间划分的哪个区域,然后将该区域的多数目标值(如果是纯的叶结点,就是单一目标值)作为预测结果。从根结点开始对树进行遍历就可以找到这一区域,每一步向左还是向右取决于是否满足相应的测试。

决策树也可以用于回归任务,使用的方法完全相同。预测的方法是,基于每个结点的测试对树进行遍历,最终找到新数据点所属的叶结点。这一数据点的输出即为此叶结点中所有训练点的平均目标值。

2.控制决策树的复杂度

通常来说,构造决策树直到所有叶结点都是纯的叶结点,这会导致模型非常复杂,并且对训练数据高度过拟合。纯叶结点的存在说明这棵树在训练集上的精度是100%。训练集中的每个数据点都位于分类正确的叶结点中。在图2-26 的左图中可以看出过拟合。这并不是人们想象中决策边界的样子,这个决策边界过于关注远离同类别其他点的单个异常点。

防止过拟合有两种常见的策略:一种是及早停止树的生长,也叫预剪枝(pre-pruning);另一种是先构造树,但随后删除或折叠信息量很少的结点,也叫后剪枝(post-pruning)或剪枝(pruning)。预剪枝的限制条件可能包括限制树的最大深度、限制叶结点的最大数目,或者规定一个结点中数据点的最小数目来防止继续划分。

scikit-learn的决策树在DecisionTreeRegressor类和DecisonTreeClassifier类中实现,只有预剪枝。

我们在乳腺癌数据集上更详细地看一下预剪枝的效果。我们导入数据集并将其分为训练集和测试集。将树完全展开(树不断分支,直到所有叶结点都是纯的)。我们固定树的random_state,用于在内部解决平局问题:

In [7]:

# 存在过拟合

from sklearn.tree import DecisionTreeClassifier as Dtc

from sklearn.datasets import load_breast_cancer

from sklearn.model_selection import train_test_split

cancer = load_breast_cancer()

x_train,x_test,y_train,y_test = train_test_split(cancer.data,cancer.target,stratify=cancer.target,random_state=42)

tree = Dtc(random_state=0)

tree.fit(x_train,y_train)

print("Accuracy on training set:{:.3f}".format(tree.score(x_train,y_train)))

print("Accuracy on test set:{:.3f}".format(tree.score(x_test,y_test)))

Accuracy on training set:1.000

Accuracy on test set:0.937

不出所料,训练集上的精度是100%,这是因为叶结点都是纯的,树的深度很大,足以完美地记住训练数据的所有标签。测试集精度比线性模型略低,线性模型的精度约为95%。

如果不限制决策树的深度,它的深度和复杂度都可以变得特别大。因此,未剪枝的树容易过拟合,对新数据的泛化性能不佳。现在将预剪枝应用在决策树上。一种选择是在到达一定深度后停止树的展开。这里我们设置max_depth=4,这意味着只可以连续问4 个问题(参见图2-24 和图2-26)。限制树的深度可以减少过拟合。这会降低训练集的精度,但可以提高测试集的精度,代码如下:

In [38]:

# 使用预剪枝,max_depth=4

tree = Dtc(max_depth=4,random_state=0)

tree.fit(x_train,y_train)

print("Accuracy on training set:{:.3f}".format(tree.score(x_train,y_train)))

print("Accuracy on test set:{:.3f}".format(tree.score(x_test,y_test)))

Accuracy on training set:0.988

Accuracy on test set:0.951

3.分析决策树

我们可以利用tree 模块的export_graphviz 函数来将树可视化。这个函数会生成一个.dot 格式的文件,这是一种用于保存图形的文本文件格式。我们设置为结点添加颜色的选项,颜色表示每个结点中的多数类别,同时传入类别名称和特征名称,这样可以对树正确标记。

我们可以利用graphviz 模块读取这个文件并将其可视化,见图2-27:

In [39]:

from sklearn.tree import export_graphviz

from sklearn.externals.six import StringIO

dot_data = StringIO()

export_graphviz(tree,out_file=dot_data,class_names=['malignant','benign'],

feature_names=cancer.feature_names,impurity=False,filled=True)

import graphviz

data = dot_data.getvalue()

img = graphviz.Source(data,format="png")

#display(img) 不缩放

img.render("tree")

#plt.figure(figsize=(20,20)) #缩放图片

plt.figure(figsize=(20,12),dpi=300)

ax = plt.gca()

ax.set_axis_off()

plt.imshow(plt.imread("tree.png"))

plt.suptitle("图2-27:基于乳腺癌数据集构造的决策树的可视化",y=0.17,fontsize=18);

树的可视化有助于深入理解算法是如何进行预测的,也是易于向非专家解释的机器学习算法的优秀示例。不过,即使这里树的深度只有4 层,也有点太大了。深度更大的树(深度为10 并不罕见)更加难以理解。

一种观察树的方法可能有用,就是找出大部分数据的实际路径。图2-27 中每个结点的samples 给出了该结点中的样本个数,values 给出的是每个类别的样本个数。观察worst radius <= 16.795 分支右侧的子结点,我们发现它只包含8 个良性样本,但有134 个恶性样本。树的这一侧的其余分支只是利用一些更精细的区别将这8 个良性样本分离出来。在第一次划分右侧的142 个样本中,几乎所有样本(132 个)最后都进入最右侧的叶结点中。

再来看一下根结点的左侧子结点,对于worst radius > 16.795,我们得到25 个恶性样本和259 个良性样本。几乎所有良性样本最终都进入左数第二个叶结点中,大部分其他叶结点都只包含很少的样本。

4.树的特征重要性

查看整个树可能非常费劲,除此之外,还可以利用一些有用的属性来总结树的工作原理。其中最常用的是特征重要性(feature importance),它为每个特征对树的决策的重要性进行排序。对于每个特征来说,它都是一个介于0 和1 之间的数字,其中0 表示“根本没用到”,1 表示“完美预测目标值”。特征重要性的求和始终为1:

In [10]:

print("Teature importances:\n{}".format(tree.feature_importances_))

Teature importances:

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.01 0.048

0. 0. 0.002 0. 0. 0. 0. 0. 0.727 0.046 0. 0.

0.014 0. 0.018 0.122 0.012 0. ]

In [11]:

def plot_feature_importances_cancer(model):

n_features = cancer.data.shape[1]

plt.figure(figsize=(10,7.5))

plt.barh(range(n_features),model.feature_importances_,align='center')

plt.yticks(np.arange(n_features),cancer.feature_names)

plt.xlabel("Feature importance")

plt.ylabel("Feature")

plot_feature_importances_cancer(tree)

plt.suptitle("图2-28:在乳腺癌数据集上学到的决策树的特征重要性",y=0.05,fontsize=18);

这里我们看到,顶部划分用到的特征(“worst radius”)是最重要的特征。这也证实了我们在分析树时的观察结论,即第一层划分已经将两个类别区分得很好。但是,如果某个特征的feature_importance_ 很小,并不能说明这个特征没有提供任何信息。这只能说明该特征没有被树选中,可能是因为另一个特征也包含了同样的信息。

与线性模型的系数不同,特征重要性始终为正数,也不能说明该特征对应哪个类别。特征重要性告诉我们“worst radius”(最大半径)特征很重要,但并没有告诉我们半径大表示样本是良性还是恶性。事实上,在特征和类别之间可能没有这样简单的关系,你可以在下面的例子中看出这一点(图2-29 和图2-30):

In [41]:

from sklearn.datasets import make_blobs

from sklearn.tree import DecisionTreeClassifier, export_graphviz

from mglearn.tools import discrete_scatter

from mglearn.plot_2d_separator import plot_2d_separator

x, y = make_blobs(centers=4, random_state=8)

y = y % 2

plt.figure(figsize=(10,7.5))

discrete_scatter(x[:, 0], x[:, 1], y)

plt.legend(["Class 0", "Class 1"], loc="best")

tree = DecisionTreeClassifier(random_state=0).fit(x, y)

plot_2d_separator(tree, x, linestyle="dashed")

plt.suptitle("图2-29:一个二维数据集(y 轴上的特征与类别标签是非单调的关系)与决策树给出的决策边界",y=0.1,fontsize=18);

In [42]:

export_graphviz(tree, out_file="mytree.dot", impurity=True, filled=True)

with open("mytree.dot") as f:

dot_graph = f.read()

print("Feature importances: %s" % tree.feature_importances_)

display(graphviz.Source(dot_graph))

plt.figure(figsize=(20,1))

ax = plt.gca()

ax.set_axis_off()

plt.suptitle("图2-30:从图2-29 的数据中学到的决策树",x=0.2,y=1,fontsize=28);

Feature importances: [0. 1.]

图2-30:从图2-29 的数据中学到的决策树

上图显示的是有两个特征和两个类别的数据集。这里所有信息都包含在x[1] 中,没有用到x[0]。但x[1] 和输出类别之间并不是单调关系,即我们不能这么说:“较大的x[1] 对应类别0,较小的x[1] 对应类别1”(反之亦然)。

虽然我们主要讨论的是用于分类的决策树,但对用于回归的决策树来说,所有内容都是类似的,在DecisionTreeRegressor 中实现。回归树的用法和分析与分类树非常类似。但在将基于树的模型用于回归时,我们想要指出它的一个特殊性质。DecisionTreeRegressor(以及其他所有基于树的回归模型)不能外推(extrapolate),也不能在训练数据范围之外进行预测。

我们利用计算机内存(RAM)历史价格的数据集来更详细地研究这一点。图2-31 给出了这个数据集的图像,x 轴为日期,y 轴为那一年1 兆字节(MB)RAM 的价格:

In [14]:

# 用对数坐标绘制RAM价格的历史发展

import pandas as pd

ram_prices = pd.read_csv('data/ram_price.csv')

plt.figure(figsize=(8,6))

#plt.semilogy(ram_prices.date,ram_prices.price)

plt.plot(ram_prices.date,np.log10(ram_prices.price))

plt.xlabel("Year")

plt.ylabel("Price in $/Mbyte");

plt.suptitle("图2-31:用对数坐标绘制RAM 价格的历史发展",y=0.05,fontsize=18);

注意y 轴的对数刻度。在用对数坐标绘图时,二者的线性关系看起来非常好,所以预测应该相对比较容易,除了一些不平滑之处之外。我们将利用2000 年前的历史数据来预测2000 年后的价格,只用日期作为特征。我们将对比两个简单的模型:DecisionTreeRegressor 和LinearRegression。我们对价格取对数,使得二者关系的线性相对更好。这对DecisionTreeRegressor 不会产生什么影响,但对LinearRegression 的影响却很大(我们将在第4 章中进一步讨论)。训练模型并做出预测之后,我们应用指数映射来做对数变换的逆运算。为了便于可视化,我们这里对整个数据集进行预测,但是为了定量评估,我们将只考虑测试数据集:

In [15]:

from sklearn.tree import DecisionTreeRegressor as Dtr

from sklearn.linear_model import LinearRegression

data_train = ram_prices[ram_prices.date < 2000]

data_test = ram_prices[ram_prices.date >= 2000]

x_train = data_train.date[:,np.newaxis]

y_train = np.log(data_train.price)

tree = Dtr().fit(x_train,y_train)

linear_reg = LinearRegression().fit(x_train,y_train)

x_all = ram_prices.date[:,np.newaxis]

pred_tree = tree.predict(x_all)

pred_lr = linear_reg.predict(x_all)

price_tree = np.exp(pred_tree)

price_lr = np.exp(pred_lr)

plt.figure(figsize=(8,6))

plt.plot(data_train.date,np.log10(data_train.price),label="Training data")

plt.plot(data_test.date,np.log10(data_test.price),label="Test data")

plt.plot(ram_prices.date,np.log10(price_tree),label="Tree prediction")

plt.plot(ram_prices.date,np.log10(price_lr),label="Linear prediction")

plt.legend()

plt.suptitle("图2-32 线性模型和回归树对RAM价格数据的预测结果对比",y=0.05,fontsize=18);

两个模型之间的差异非常明显。线性模型用一条直线对数据做近似,这是我们所知道的。这条线对测试数据(2000 年后的价格)给出了相当好的预测,不过忽略了训练数据和测试数据中一些更细微的变化。与之相反,树模型完美预测了训练数据。由于我们没有限制树的复杂度,因此它记住了整个数据集。但是,一旦输入超出了模型训练数据的范围,模型就只能持续预测最后一个已知数据点。树不能在训练数据的范围之外生成“新的”响应。所有基于树的模型都有这个缺点。

5.优点、缺点和参数

如前所述,控制决策树模型复杂度的参数是预剪枝参数,它在树完全展开之前停止树的构造。通常来说,选择一种预剪枝策略(设置max_depth、max_leaf_nodes 或min_samples_leaf)足以防止过拟合。

与前面讨论过的许多算法相比,决策树有两个优点:一是得到的模型很容易可视化,非专家也很容易理解(至少对于较小的树而言);二是算法完全不受数据缩放的影响。由于每个特征被单独处理,而且数据的划分也不依赖于缩放,因此决策树算法不需要特征预处理,比如归一化或标准化。特别是特征的尺度完全不一样时或者二元特征和连续特征同时存在时,决策树的效果很好。

决策树的主要缺点在于,即使做了预剪枝,它也经常会过拟合,泛化性能很差。因此,在大多数应用中,往往使用下面介绍的集成方法来替代单棵决策树。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码