百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

常见隐藏单元(代码篇)|机器学习你会遇到的“坑”

toyiye 2024-07-09 22:48 13 浏览 0 评论




我们在上一节的《常见隐藏单元》(理论篇)中基于阶跃函数和线性函数逐渐改进了隐藏单元,比如sigmoid函数改进了阶跃函数不光滑以及梯度几乎为零的缺点,tanh又改进了sigmoid函数非零中心的问题,反正切函数又改进了tanh函数过早饱和的缺点,LeakyReLU改进了ReLU的死亡问题,PReLU则继续把原本的超参数变为一个可训练的参数,maxout单元又继续增强了激活函数的灵活性。

我们选取MNIST数据集作为材料,这是一个手写数字识别的任务,在机器学习和深度学习测试中被广泛使用,在keras中,我们只需要使用一条命令就可以将数据下载到本地。我们要验证不同的激活函数对于神经网络性能的影响,如果激活函数具有梯度消失的缺点,那么我们就将网络搭的尽可能深,从理论上来说,训练过程中梯度流从一层流入更深的层中,问题才会变得更严重。

需要注意的是,对于MNIST而言,每一张图片都是灰度图像,我们传入神经网络的是像素值,一个(28,28)的图像,我们需要设置输入维数为28*28;同时为了避免极端值对收敛和表示的影响,我们需要对图片做归一化,将所有像素值缩放到一个范围内;再者,我们需要对target做one-hot encoding,比如标签"5",经过one-hot encoding就会变为[0,0,0,0,0,1,0,0,0,0],因为原本的标签是有序的,不同的标签之间存在大小关系,比如"4"似乎比“5”更接近“3”,但这样是不合理的,one-hot encoding则会取消这样的序关系,使得距离(相似性)计算更加合理。

我们利用keras来进行上述的过程:

import numpy as np

from keras.datasets import mnist

from keras.utils import to_categorical

(X_train,y_train),(X_test,y_test)=mnist.load_data()

train_labels =to_categorical(y_train)

test_labels = to_categorical(y_test)

X_train_normal = X_train.reshape(60000,28*28)

X_train_normal = X_train_normal.astype('float32') /255

X_test_normal = X_test.reshape(10000, 28*28)

X_test_normal = X_test_normal.astype('float32') /255


我们获得了6万个训练样本,还有1万个测试样本,并分别做了one-hot encoding和归一化。我们接下来使用keras的Sequential模式来搭建网络,当网络简单而且结构较深的时候,这样的方式更加高效:

from keras import models

from keras.layers import Dense

from keras import optimizers

def normal_model(a):

model=models.Sequential()

model.add(Dense(512,activation=a,input_shape=(28*28,)))

model.add(Dense(256,activation=a))

model.add(Dense(128,activation=a))

model.add(Dense(64,activation=a))

model.add(Dense(10,activation='softmax'))

model.compile(optimizer=optimizers.SGD(momentum=0.9,nesterov=True),\

loss='categorical_crossentropy',metrics=['accuracy'])

return(model)


此处,我们引入了一个参数a用来指定激活函数的类型,因为keras的激活函数可以通过名字的简单的指定,我们将输出单元设置为softmax,这是一个与交叉熵有关的输出单元,我们使用了带有nestrov动量的随机梯度下降的优化算法,交叉熵作为Loss,准确率作为性能度量,函数最后会返回一个编译好的包含4个隐层的神经网络。

我们先来选取sigmoid函数作为激活函数,获取训练结果,保存在变量中:

model_1=normal_model('sigmoid')

his=model_1.fit(X_train_normal,train_labels,\

batch_size=128,\

validation_data=(X_test_normal,test_labels),\

verbose=1,epochs=10)


然后,我们可以绘制Loss随着epochs变化的图像:

import matplotlib.pyplot as plt

import seaborn as sns

sns.set(style='whitegrid')

plt.plot(range(20),his.history['val_loss'],label='validation loss')

plt.plot(range(20),his.history['loss'],label=' train loss')

plt.title('Loss')

plt.legend()

plt.show()




如图,采用sigmoid作为激活函数,Loss在第19个epochs时才会收敛,loss的最小值为0.25。

我们将激活函数换为Tanh,并重复上述步骤,预想到零中心化的激活会使得迭代更新的速度更快:

......

model_1=normal_model('tanh')

......



如图,tanh函数比起sigmoid函数,Loss下降的非常快,而且可以将Loss降到零附近。

我们还可以继续更换Softsign、ReLU、ELU来观察是否具有更快的速度,我们只需要把每次运行的结果存起来,然后在同一张图里观察其效果:

w={}

for a in ['sigmoid','tanh','softsign','relu','elu']:

model_1=normal_model(a)

his=model_1.fit(X_train_normal,train_labels,\

batch_size=128,validation_data=(X_test_normal,test_labels),\

verbose=1,epochs=20)

w[a]=his.history

sns.set(style='whitegrid')

plt.figure()

for a in w.keys():

plt.plot(range(20),w[a]['val_loss'],label=a)

plt.title('model Loss')

plt.legend()

plt.show()




如图,sigmoid是表现最差的,其余的损失函数看起来不相上下。

我们去掉sigmoid函数,对其他几个做详细对比:



如图,tanh,softsign,relu,elu函数均在第8个epochs之后就收敛到了很低的水平,从图中可以看出,relu是收敛最快的,但最终的收敛的Loss却比softsign和elu都要大。

但训练收敛的更快只能说明更易于学习,并不能说明学习的效果好,即准确率高,因为激活函数的不同类型可能具备不同的表达能力,使得在神经网络结构大体一致的情况下,表示能力可能会有差别,我们可以在20个epochs之后获得其最终的准确率:

plt.figure()

f=[]

g=[]

for a in w.keys():

f.append(a)

g.append(w[a]['val_acc'][-1])

sns.barplot(f,g)



如图,另外四种的准确率要优于sigmoid函数,但这四种似乎差别不大,鉴于此,我们可以说不同激活函数对于复杂的神经网络表示能力上的差别很微小,所以,我们有把握说sigmoid函数通过更多迭代步骤或许也可以达到与其他激活函数相媲美的效果。

如果我们想使用可以近似任意凸函数的maxout单元,那么就可以使用keras中的MaxoutDense层,注意,第一个数字表示输出的维数,与Dense层类似,第二个参数nb_feature,表示maxout的分段函数的段数,也就是《常见隐藏单元理论篇》中的k。

from keras.layers import MaxoutDense

def maxout_model():

model=models.Sequential()

model.add(MaxoutDense(512,nb_feature=3,input_shape=(28*28,)))

model.add(MaxoutDense(256,nb_feature=3))

model.add(MaxoutDense(128,nb_feature=3))

model.add(MaxoutDense(64,nb_feature=3))

model.add(Dense(10,activation='softmax'))

model.compile(optimizer=optimizers.SGD(momentum=0.9,nesterov=True),\

loss='categorical_crossentropy',\

metrics=['accuracy'])

return(model)


此外,如果我们还有一些新奇的点子,但在keras原本的激活函数中未被包含,那么就需要灵活使用后端来自定义我们需要的函数,比如我们想使用余弦函数,只需要利用后端定义好,然后将其加入我们预先准备好的神经网络框架:

from keras import backend as K

def cos(x):

return K.cos(x)

def normal_model(a):

model=models.Sequential()

model.add(Dense(512,activation=a,input_shape=(28*28,)))

model.add(Dense(256,activation=a))

model.add(Dense(128,activation=a))

model.add(Dense(64,activation=a))

model.add(Dense(10,activation='softmax'))

model.compile(optimizer=optimizers.SGD(momentum=0.9,nesterov=True),\

loss='categorical_crossentropy',\

metrics=['accuracy'])

return(model)

model_1=normal_model(cos)

.....


然后就可以像上述的步骤一样进行训练。


读芯君开扒

课堂TIPS

? 在keras中,添加激活函数有两种途径,一种是在现有层中指定激活函数的类型,另一种是在现有层的基础上继续添加激活层,这两者方法是等价的,但对神经网络结构不熟悉的读者推荐使用第二种。

? 我们在这里使用了softmax输出单元以及交叉熵作为损失函数,以及上一节提到的,linear输出单元。输出单元和损失函数的变更是神经网络非常重要的一环,我们将在下一节详细讲解。




作者:唐僧不用海飞丝

如需转载,请后台留言,遵守转载规范

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码