百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

使用Scikit-Learn库对Keras模型进行超参数调整

toyiye 2024-07-08 23:07 11 浏览 0 评论

Keras是用于Python编程语言的神经网络库,能够与Theano,R或TensorFlow等许多深度学习工具一起运行,并允许快速迭代以进行神经网络的实验或原型设计。

无论您是在Keras中对神经网络模型进行原型设计以了解其将如何执行所需任务,还是对已构建和测试的模型进行微调,都需要为机器学习模型考虑许多参数。这些机器学习模型参数称为超参数。在层中使用的激活函数就是超参数的的示例。机器学习模型中的层数,每层神经元数或卷积神经网络中核的大小都可以视为超级参数。

超参数没有固定公式,不同的问题将需要不同的方法。更改模型的每个参数可能会影响其性能,只有实验才能确定哪种组合最适合您的模型和数据。

在本文中,我们将研究使用机器学习库Scikit-Learn执行超参数调整以优化Keras模型所需的步骤。我们将构建一个简单的神经网络,并使用Scikit-Learn库中的RandomizedSearchCV对象寻找最佳优化器、批量大小和激活。

准备工作

我们将在示例中使用的库是TensorFlow,其中包括Keras和Scikit Learn。

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense,Flatten
from tensorflow.keras.datasets import mnist
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import RandomizedSearchCV

我们还将使用numpy和matplotlib库:

import numpy as np
import matplotlib.pyplot as plt

准备数据

首先,让我们使用一个数据集,对其进行格式化并构建我们的机器学习模型。在这里,我们进行归一化并打印其形状以确保我们为模型使用正确的输入:

(X_train, y_trn), (X_test, y_tst) = mnist.load_data()
X_trn = X_train[..., np.newaxis].astype(np.float32) / 255.
X_tst = X_test[..., np.newaxis].astype(np.float32) / 255.
print(X_train.shape,y_trn.shape)
print(X_test.shape,y_tst.shape)

mnist数据集是一组28x28像素的手写数字图片。
我们的数据如下所示:

def preview(data,result):
    """Shows 12 elements of picture dataset"""
    fig = plt.figure()
    for i in range(12):
        plt.subplot(2,6,i+1)
        plt.imshow(data[i], interpolation='none')
        plt.title("label:{}".format(result[i]))
        plt.xticks([])
        plt.yticks([])
preview(X_train[12:],y_trn[12:])

建立模型

为了使用scikit-learn调整Keras模型的参数,我们需要能够使用不同的参数重建模型。为此,我们创建一个函数来基于我们的超参数构建模型:

def build_model(var_activation='relu',var_optimizer='adam'):
  """ Uses arguments to build Keras model. """
  model = Sequential()
  model.add(Flatten(input_shape=[28, 28, 1]))
  model.add(Dense(64,activation=var_activation))
  model.add(Dense(32,activation=var_activation))
  model.add(Dense(16,activation=var_activation))
  model.add(Dense(10,activation='softmax'))
  model.compile(loss="sparse_categorical_crossentropy",
                optimizer=var_optimizer,
                metrics=["accuracy"])
  return model

这是我们的模型在默认参数下的样子:

model_default = build_model()
model_default.summary()

设置变量

我们想使用Adam算法和随机梯度下降来测试模型的性能,并测试不同层的激活函数和批量大小来训练模型。让我们创建参数列表并将它们存储为字典。字典中的键是在我们的模型中使用的变量的名称:

_activations=['tanh','relu','selu']
_optimizers=['sgd','adam']
_batch_size=[16,32,64]
params=dict(var_activation=_activations,
            var_optimizer=_optimizers,
            batch_size=_batch_size)
print(params)

注意,' batch_size '不是build_model函数中的变量,而是.fit()调用中稍后将使用的变量,以训练我们创建的模型。

根据Keras模型创建scikit学习估算器

现在我们有了数据,构建模型的功能以及要测试的参数,我们可以使用sklearn库根据我们的函数和超参数测试不同的模型。我们可以使用sklearn.model_selection模块中的GridSearchCV或RandomizedSearchCV对象来迭代超参数的不同组合,并输出得分最高的模型。GridSearchCV对象将遍历超参数的所有可能组合,而RandomizedSearchCV对象将随机采样许多可能的组合以训练模型。尽管使用随机搜索可能并不总是提供最佳的可能模型,但由随机搜索要快得多,资源消耗也少得多,这使得随机模型搜索对于测试和原型设计非常有用。要使用RandomizedSearchCV,我们首先需要使我们的Keras模型与sklearn库兼容,我们将对scikitlearn使用Keras包装器:KerasClassifier。

model = KerasClassifier(build_fn=build_model,epochs=4,batch_size=16)

在拟合我们的RandomizedSearch对象之前,我们使用numpy.random.seed()设置随机种子。将种子设置为随机数生成器将使我们的模型权重初始化与每次迭代相同,从而使我们的搜索更有意义。但是,如果我们的超参数包含层数或层中节点数,则将无济于事,因为我们将比较完全不同的模型。

np.random.seed(42)

使用RandomizedSearchCV

创建KerasClassifier后,我们将创建RandomizedSearchCV对象,并使用.fit()方法开始搜索最佳模型。RandomizedSearchCV允许我们使用参数n_iter明确控制尝试的组合数量。

rscv = RandomizedSearchCV(model, param_distributions=params, cv=3,     n_iter=10)
rscv_results = rscv.fit(X_trn,y_trn)

这是我们搜索的结果:

print('Best score is: {} using {}'.format(rscv_results.best_score_,
rscv_results.best_params_))

结论

超参数调优可用于微调所选模型,或搜索最适合该任务的模型。它还可以帮助评估模型的学习速度。上面的方法可以进一步扩展,包括使用来自scikit-learn库的GridSearchCV对象进行更详尽的搜索,或者为模型的结构添加参数,如层数。可以添加回调以防止过拟合测试模型。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码