百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

Python实现机器学习算法——决策树之ID3算法

toyiye 2024-06-21 12:20 9 浏览 0 评论

作为机器学习中的一大类模型,树模型一直以来都颇受学界和业界的重视。目前无论是各大比赛各种大杀器的XGBoost、lightgbm还是像随机森林、Adaboost等典型集成学习模型,都是以决策树模型为基础的。

传统的经典决策树算法包括ID3算法、C4.5算法以及GBDT的基分类器CART算法。

三大经典决策树算法最主要的区别在于其特征选择准则的不同。ID3算法选择特征的依据是信息增益、C4.5是信息增益比,而CART则是Gini指数。

作为一种基础的分类和回归方法,决策树可以有如下两种理解方式。一种是我们可以将决策树看作是一组if-then规则的集合,另一种则是给定特征条件下类的条件概率分布。

根据上述两种理解方式,我们可以将决策树的本质视作从训练数据集中归纳出一组分类规则,也可以将其看作是根据训练数据集估计条件概率模型。

整个决策树的学习过程就是一个递归地选择最优特征,并根据该特征对数据集进行划分,使得各个样本都得到一个最好的分类的过程。


ID3算法理论

所以这里的关键在于如何选择最优特征对数据集进行划分。答案就是前面提到的信息增益、信息增益比和Gini指数。因为本篇针对的是ID3算法,所以这里仅对信息增益进行详细的表述。

在讲信息增益之前,这里我们必须先介绍下熵的概念。在信息论里面,熵是一种表示随机变量不确定性的度量方式。若离散随机变量X的概率分布为:

则随机变量X的熵定义为:

同理,对于连续型随机变量Y,其熵可定义为:

当给定随机变量X的条件下随机变量Y的熵可定义为条件熵H(Y|X):

所谓信息增益就是数据在得到特征X的信息时使得类Y的信息不确定性减少的程度。假设数据集D的信息熵为H(D),给定特征A之后的条件熵为H(D|A),则特征A对于数据集的信息增益g(D,A)可表示为:

g(D,A) = H(D) - H(D|A)

信息增益越大,则该特征对数据集确定性贡献越大,表示该特征对数据有较强的分类能力。信息增益的计算示例如下:
1).计算目标特征的信息熵。

2).计算加入某个特征之后的条件熵。


3).计算信息增益。


以上就是ID3算法的核心理论部分,至于如何基于ID3构造决策树,我们在代码实例中来看。



ID3算法实现

先读入示例数据集:

import numpy as np
import pandas as pd
from math import log

df = pd.read_csv('./example_data.csv')
df

定义熵的计算函数:

def entropy(ele):    
    '''
    function: Calculating entropy value.
    input: A list contain categorical value.
    output: Entropy value.
    entropy = - sum(p * log(p)), p is a prob value.
    '''
    # Calculating the probability distribution of list value
    probs = [ele.count(i)/len(ele) for i in set(ele)]    
    # Calculating entropy value
    entropy = -sum([prob*log(prob, 2) for prob in probs])    
    return entropy

计算示例:

然后我们需要定义根据特征和特征值进行数据划分的方法:

def split_dataframe(data, col):    
    '''
    function: split pandas dataframe to sub-df based on data and column.
    input: dataframe, column name.
    output: a dict of splited dataframe.
    '''
    # unique value of column
    unique_values = data[col].unique()    
    # empty dict of dataframe
    result_dict = {elem : pd.DataFrame for elem in unique_values}    
    # split dataframe based on column value
    for key in result_dict.keys():
        result_dict[key] = data[:][data[col] == key]    
    return result_dict

根据temp和其三个特征值得数据集划分示例:

然后就是根据熵计算公式和数据集划分方法计算信息增益来选择最佳特征的过程:

def choose_best_col(df, label):    
    '''
    funtion: choose the best column based on infomation gain.
    input: datafram, label
    output: max infomation gain, best column, 
            splited dataframe dict based on best column.
    '''
    # Calculating label's entropy
    entropy_D = entropy(df[label].tolist())    
    # columns list except label
    cols = [col for col in df.columns if col not in [label]]    
    # initialize the max infomation gain, best column and best splited dict
    max_value, best_col = -999, None
    max_splited = None
    # split data based on different column
    for col in cols:
        splited_set = split_dataframe(df, col)
        entropy_DA = 0
        for subset_col, subset in splited_set.items():            
            # calculating splited dataframe label's entropy
            entropy_Di = entropy(subset[label].tolist())            
            # calculating entropy of current feature
            entropy_DA += len(subset)/len(df) * entropy_Di        
        # calculating infomation gain of current feature
        info_gain = entropy_D - entropy_DA        
        if info_gain > max_value:
            max_value, best_col = info_gain, col
            max_splited = splited_set    
        return max_value, best_col, max_splited

最先选到的信息增益最大的特征是outlook:

决策树基本要素定义好后,我们即可根据以上函数来定义一个ID3算法类,在类里面定义构造ID3决策树的方法:

class ID3Tree:    
    # define a Node class
    class Node:        
        def __init__(self, name):
            self.name = name
            self.connections = {}    
            
        def connect(self, label, node):
            self.connections[label] = node    
        
    def __init__(self, data, label):
        self.columns = data.columns
        self.data = data
        self.label = label
        self.root = self.Node("Root")    
    
    # print tree method
    def print_tree(self, node, tabs):
        print(tabs + node.name)        
        for connection, child_node in node.connections.items():
            print(tabs + "\t" + "(" + connection + ")")
            self.print_tree(child_node, tabs + "\t\t")    
    
    def construct_tree(self):
        self.construct(self.root, "", self.data, self.columns)    
    
    # construct tree
    def construct(self, parent_node, parent_connection_label, input_data, columns):
        max_value, best_col, max_splited = choose_best_col(input_data[columns], self.label)        
        if not best_col:
            node = self.Node(input_data[self.label].iloc[0])
            parent_node.connect(parent_connection_label, node)            
        return

        node = self.Node(best_col)
        parent_node.connect(parent_connection_label, node)

        new_columns = [col for col in columns if col != best_col]        
        # Recursively constructing decision trees
        for splited_value, splited_data in max_splited.items():
            self.construct(node, splited_value, splited_data, new_columns)

根据上述代码和示例数据集构造一个ID3决策树:

以上便是ID3算法的手写过程。sklearn中tree模块为我们提供了决策树的实现方式,参考代码如下:

from sklearn.datasets import load_iris
from sklearn import tree
import graphviz

iris = load_iris()
# criterion选择entropy,这里表示选择ID3算法
clf = tree.DecisionTreeClassifier(criterion='entropy', splitter='best')
clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                               feature_names=iris.feature_names,
                               class_names=iris.target_names,
                               filled=True, 
                               rounded=True,
                               special_characters=True)
graph = graphviz.Source(dot_data)
graph

原文参考资料:

李航 统计学习方法

https://github.com/heolin123/id3/blob/master

https://mp.weixin.qq.com/s/6ixsCP8dvNYfqhQYUbnNHw

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码