百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

CatBoost机器学习模型的介绍与实例

toyiye 2024-07-06 00:16 10 浏览 0 评论

CatBoost是Yandex开发的梯度提升机器学习算法。

梯度提升是一种迭代算法,通常基于决策树。首先,建立一个基础模型,该模型具有较高的误差。考虑到该模型的误差,再建立另一个模型,通过这种方式,经过数千次迭代后,误差被最小化。

我们知道,梯度提升技术通常比神经网络在异构数据集上提供更好的结果。异构数据是分类、数字和文本特征的混合。神经网络通常更擅长同质数据。梯度提升在数据具有异构属性的生产和比赛中大量使用。这就是为什么在Kaggle比赛中经常使用XGBoost的原因。

CatBoost算法试图从给定的数据中获得最大的信息。因此,它可以很好地处理小型机器学习数据集。

与其提升模型相比,我们可以说 CatBoost 的表现相当出色。默认超参数取决于机器学习数据集。

与其提升算法不同,CatBoost使用对称全二叉树。这样一来,树是更简单的结构,我们也就避免了过度拟合的危险。此外,由于我们的基础模型结构简单,我们有更快的预测器。

有序目标编码

CatBoost 的标志之一是它处理分类特征的方法(CatBoost 是 Categorical Boosting)。如果分类特征在您的数据集中占主导地位,CatBoost 是我们应该首先尝试的算法之一。

通常情况下,one hot方法被应用于分类特征。通过这样的编码方法,我们可以从单个特征中得到多个特征

CatBoost的优点是它可以处理开箱即用的数据。目标编码过程中可能发生数据泄漏。也就是说,目标特征信息不应该泄漏到模型中。为了防止这种情况,CatBoost使用了一种智能方法,它进行某种基于目标但有序的编码。

例如,假设我们有一个分类特征和目标特征,如下例所示:

首先,打乱和重新排列数据的顺序。让我们假设上面的顺序是重新排列的。

Prior 是传递给算法的参数,通常为 0.5。current count是训练数据集中具有当前类别组的所在行之前的目标总数。记住,它是基于目标的,所以我们应该寻找具有相同目标值的行。例如,我们对第五行进行编码:

如果第三行的目标值为0,则

正如你所注意到的,我们会对数据集中已经重新排序的每个Germany有不同的数值。

最小方差采样

编码后,CatBoost使用一种称为MVS的采样方法,即最小方差采样。这意味着加权采样应用于树的层次。它会将概率分配给它必须选择的观测值,以便最大限度地提高准确性。

Python演示

按照库的命令;

pip install catboost

使用库中的Amazon机器学习数据集:

import os
import pandas as pd
import numpy as np
np.set_printoptions(precision=4)
import catboost
print(catboost.__version__)

#dataset

from catboost.datasets import amazon
(train_df, test_df) = amazon()

进行数据预处理,Python代码如下;

y = train_df.ACTION
X = train_df.drop('ACTION', axis=1)

#all the features are categorical
cat_features = list(range(0, X.shape[1]))
print(cat_features)

#unbalanced labels
print('Labels: {}'.format(set(y)))
print('Zero count = {}, One count = {}'.format(len(y) - sum(y), sum(y)))

与其他模型不同,我们将分类特征索引作为参数传递给训练对象。

#training
from catboost import CatBoostClassifier
model = CatBoostClassifier(iterations=100)
model.fit(X, y, cat_features=cat_features, verbose=10)

我们来进行预测,注意第一列是0的概率,第二列是1的概率。

model.predict_proba(X)
#输出: array([[0.0098, 0.9902],
       [0.0101, 0.9899],
       [0.0579, 0.9421],
       ...,
       [0.0118, 0.9882],
       [0.1891, 0.8109],
       [0.0235, 0.9765]])

我们已经用 CatBoost 训练了一个机器学习模型。下一步,让我们进一步完善我们的模型。在这里,我们的机器学习数据集是不平衡的。在创建具有不平衡数据集的模型时,使用权重通常是首选的方法。创建一个具有权重的列,给出了稀有类的大权重和频繁类的小权重。这里,我们使用pool类。

from catboost import Pool
pool = Pool(data=X, label=y, cat_features=cat_features)

from sklearn.model_selection import train_test_split

data = train_test_split(X, y, test_size=0.2, random_state=0)
X_train, X_validation, y_train, y_validation = data

train_pool = Pool(
    data=X_train, 
    label=y_train, 
    cat_features=cat_features
)

validation_pool = Pool(
    data=X_validation, 
    label=y_validation, 
    cat_features=cat_features
)

如果我们的标签中有概率,我们可以使用交叉熵,如果我们的标签中有 0 和 1,我们可以使用对数损失(logloss)作为我们的损失函数。如果不指定,模型会自动选择对数损失函数,遇到多类问题,模型会自动选择多类函数。

model = CatBoostClassifier(
    iterations=5,
    learning_rate=0.1,
    # loss_function='CrossEntropy'
)
model.fit(train_pool, eval_set=validation_pool, verbose=False)

print('Model is fitted: {}'.format(model.is_fitted()))
print('Model params:\n{}'.format(model.get_params()))

#输出内容
Model is fitted: True
Model params:
{'iterations': 5, 'learning_rate': 0.1}

#其他设置
model = CatBoostClassifier(
    iterations=15,
#     verbose=5,
)
model.fit(train_pool, eval_set=validation_pool);

model = CatBoostClassifier(
    iterations=50,
    learning_rate=0.5,
    custom_loss=['AUC', 'Accuracy']
)

model.fit(
    train_pool,
    eval_set=validation_pool,
    verbose=False,
    plot=True
);

model_with_early_stop = CatBoostClassifier(
    iterations=200,
    learning_rate=0.5,
    early_stopping_rounds=20
)

model_with_early_stop.fit(
    train_pool,
    eval_set=validation_pool,
    verbose=False,
    plot=True
);

进行交叉验证:

from catboost import cv

params = {
    'loss_function': 'Logloss',
    'iterations': 80,
    'custom_loss': 'AUC',
    'learning_rate': 0.5,
}

cv_data = cv(
    params = params,
    pool = train_pool,
    fold_count=5,
    shuffle=True,
    partition_random_seed=0,
    plot=True,
    verbose=False
)

cv_data.head(10)

基于网格搜索的参数优化:

from sklearn.model_selection import GridSearchCV

param_grid = {
    "iterations": [10,100],
    "learning_rate": [0.01,0.1],
    "depth": [4,7],
    "early_stopping_rounds" : [5,10],
    "depth" : [4,8],
    "l2_leaf_reg": [2,4]
}

clf = CatBoostClassifier(
    cat_features=cat_features, 
    verbose=20
)
grid_search = GridSearchCV(clf, param_grid=param_grid, cv=3)
results = grid_search.fit(X_train, y_train)
results.best_estimator_.get_params()

#输出内容:
{'iterations': 100,
 'learning_rate': 0.1,
 'depth': 8,
 'l2_leaf_reg': 2,
 'verbose': 20,
 'early_stopping_rounds': 5,
 'cat_features': [0, 1, 2, 3, 4, 5, 6, 7, 8]}

early_stopping_rounds:使用此参数,如果在指定的迭代次数(例如连续 50 次迭代)后没有看到改进,我们会提前终止训练。

相关推荐

如何用 coco 数据集训练 Detectron2 模型?

随着最新的Pythorc1.3版本的发布,下一代完全重写了它以前的目标检测框架,新的目标检测框架被称为Detectron2。本教程将通过使用自定义coco数据集训练实例分割模型,帮助你开始使...

CICD联动阿里云容器服务Kubernetes实践之Bamboo篇

本文档以构建一个Java软件项目并部署到阿里云容器服务的Kubernetes集群为例说明如何使用Bamboo在阿里云Kubernetes服务上运行RemoteAgents并在agents上...

Open3D-ML点云语义分割实验【RandLA-Net】

作为点云Open3D-ML实验的一部分,我撰写了文章解释如何使用Tensorflow和PyTorch支持安装此库。为了测试安装,我解释了如何运行一个简单的Python脚本来可视化名为...

清理系统不用第三方工具(系统自带清理软件效果好不?)

清理优化系统一定要借助于优化工具吗?其实,手动优化系统也没有那么神秘,掌握了方法和技巧,系统清理也是一件简单和随心的事。一方面要为每一个可能产生累赘的文件找到清理的方法,另一方面要寻找能够提高工作效率...

【信创】联想开先终端开机不显示grub界面的修改方法

原文链接:【信创】联想开先终端开机不显示grub界面的修改方法...

如意玲珑成熟度再提升,三大发行版支持教程来啦!

前期,我们已分别发布如意玲珑在deepinV23与UOSV20、openEuler24.03发行版的操作指南,本文,我们将为大家详细介绍Ubuntu24.04、Debian12、op...

118种常见的多媒体文件格式(英文简写)

MP4[?mpi?f??]-MPEG-4Part14(MPEG-4第14部分)AVI[e?vi??a?]-AudioVideoInterleave(音视频交错)MOV[m...

密码丢了急上火?码住7种console密码紧急恢复方式!

身为攻城狮的你,...

CSGO丨CS2的cfg指令代码分享(csgo自己的cfg在哪里?config文件位置在哪?)

?...

使用open SSL生成局域网IP地址证书

某些特殊情况下,用户内网访问多可文档管理系统时需要启用SSL传输加密功能,但只有IP,没有域名和证书。这种情况下多可提供了一种免费可行的方式,通过openSSL生成免费证书。此方法生成证书浏览器会提示...

Python中加载配置文件(python怎么加载程序包)

我们在做开发的时候经常要使用配置文件,那么配置文件的加载就需要我们提前考虑,再不使用任何框架的情况下,我们通常会有两种解决办法:完整加载将所有配置信息一次性写入单一配置文件.部分加载将常用配置信息写...

python开发项目,不得不了解的.cfg配置文件

安装软件时,经常会见到后缀为.cfg、.ini的文件,一般我们不用管,只要不删就行。因为这些是程序安装、运行时需要用到的配置文件。但对开发者来说,这种文件是怎么回事就必须搞清了。本文从.cfg文件的创...

瑞芯微RK3568鸿蒙开发板OpenHarmony系统修改cfg文件权限方法

本文适用OpenHarmony开源鸿蒙系统,本次使用的是开源鸿蒙主板,搭载瑞芯微RK3568芯片。深圳触觉智能专注研发生产OpenHarmony开源鸿蒙硬件,包括核心板、开发板、嵌入式主板,工控整机等...

Python9:图像风格迁移-使用阿里的接口

先不多说,直接上结果图。#!/usr/bin/envpython#coding=utf-8importosfromaliyunsdkcore.clientimportAcsClient...

Python带你打造个性化的图片文字识别

我们的目标:从CSV文件读取用户的文件信息,并将文件名称修改为姓名格式的中文名称,进行规范资料整理,从而实现快速对多个文件进行重命名。最终效果:将原来无规律的文件名重命名为以姓名为名称的文件。技术点:...

取消回复欢迎 发表评论:

请填写验证码