百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

用简单的 2D CNN 进行 MNIST 数字识别

toyiye 2024-06-21 12:21 9 浏览 0 评论

雷锋网 AI 研习社按:本文为雷锋网字幕组编译的技术博客,原标题 A simple 2D CNN for MNIST digit recognition,作者为 Sambit Mahapatra。

翻译 | 王祎 校对 | 霍雷刚 整理 | 孔令双

对于图像分类任务,当前最先进的架构是卷积神经网络 (CNNs).。无论是面部识别、自动驾驶还是目标检测,CNN 得到广泛使用。在本文中,针对著名的 MNIST 数字识别任务,我们设计了一个以 tensorflow 为后台技术、基于 keras 的简单 2D 卷积神经网络 (CNN) 模型。整个工作流程如下:

1. 准备数据

2. 创建模型并编译

3. 训练模型并评估

4. 将模型存盘以便下次使用

数据集就使用上文所提到的 MNIST 数据集。MNIST 数据集 (Modified National Institute of Standards and Technoloy 数据集) 是一个大型的手写数字(0 到 9)数据集。该数据集包含 大小为 28x28 的图片 7 万张,其中 6 万张训练图片、1 万张测试图片。第一步,加载数据集,这一步可以很容易地通过 keras api 来实现。

import keras

from keras.datasets import mnist

#load mnist dataset

(X_train, y_train), (X_test, y_test) = mnist.load_data #everytime loading data won't be so easy :)

其中,X_train 包含 6 万张 大小为 28x28 的训练图片,y_train 包含这些图片对应的标签。与之类似,X_test 包含了 1 万张大小为 28x28 的测试图片,y_test 为其对应的标签。我们将一部分训练数据可视化一下,来对深度学习模型的目标有一个认识吧。

import matplotlib.pyplot as plt

fig = plt.figure

for i in range(9):

plt.subplot(3,3,i+1)

plt.tight_layout

plt.imshow(X_train[i], cmap='gray', interpolation='none')

plt.title("Digit: {}".format(y_train[i]))

plt.xticks([])

plt.yticks([])

fig

如上所示,左上角图为「5」的图片数据被存在 X_train[0] 中,y_train[0] 中存储其对应的标签「5」。我们的深度学习模型应该能够仅仅通过手写图片预测实际写下的数字。 现在,为了准备数据,我们需要对这些图片做一些诸如调整大小、像素值归一化之类的处理。

#reshaping

#this assumes our data format

#For 3D data, "channels_last" assumes (conv_dim1, conv_dim2, conv_dim3, channels) while

#"channels_first" assumes (channels, conv_dim1, conv_dim2, conv_dim3).

if k.image_data_format == 'channels_first':

X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)

X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)

input_shape = (1, img_rows, img_cols)

else:

X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)

X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)

input_shape = (img_rows, img_cols, 1)

#more reshaping

X_train = X_train.astype('float32')

X_test = X_test.astype('float32')

X_train /= 255

X_test /= 255

print('X_train shape:', X_train.shape) #X_train shape: (60000, 28, 28, 1)

对图片数据做了必要的处理之后,需要将 y_train 和 y_test 标签数据进行转换,转换成分类的格式。例如,模型构建时,3 应该被转换成向量 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]。

import keras

#set number of categories

num_category = 10

# convert class vectors to binary class matrices

y_train = keras.utils.to_categorical(y_train, num_category)

y_test = keras.utils.to_categorical(y_test, num_category)

创建模型并编译

数据加载进模型之后,我们需要定义模型结构,并通过优化函数、损失函数和性能指标。

接下来定义的架构为 2 个卷积层,分别在每个卷积层后接续一个池化层,一个全连接层和一个 softmax 层。在每一层卷积层上都会使用多个滤波器来提取不同类型的特征。直观的解释的话,第一个滤波器有助于检测图片中的直线,第二个滤波器有助于检测图片中的圆形,等等。关于每一层技术实现的解释,将会在后续的帖子中进行讲解。如果想要更好的理解每一层的含义,可以参考 http://cs231n.github.io/convolutional-networks/

在最大池化和全连接层之后,在我们的模型中引入 dropout 来进行正则化,用以消除模型的过拟合问题。

##model building

model = Sequential

#convolutional layer with rectified linear unit activation

model.add(Conv2D(32, kernel_size=(3, 3),

activation='relu',

input_shape=input_shape))

#32 convolution filters used each of size 3x3

#again

model.add(Conv2D(64, (3, 3), activation='relu'))

#64 convolution filters used each of size 3x3

#choose the best features via pooling

model.add(MaxPooling2D(pool_size=(2, 2)))

#randomly turn neurons on and off to improve convergence

model.add(Dropout(0.25))

#flatten since too many dimensions, we only want a classification output

model.add(Flatten)

#fully connected to get all relevant data

model.add(Dense(128, activation='relu'))

#one more dropout for convergence' sake :)

model.add(Dropout(0.5))

#output a softmax to squash the matrix into output probabilities

model.add(Dense(num_category, activation='softmax'))

确定模型架构之后,模型需要进行编译。由于这是多类别的分类问题,因此我们需要使用 categorical_crossentropy 作为损失函数。由于所有的标签都带有相似的权重,我们更喜欢使用精确度作为性能指标。AdaDelta 是一个很常用的梯度下降方法。我们使用这个方法来优化模型参数。

#Adaptive learning rate (adaDelta) is a popular form of gradient descent rivaled only by adam and adagrad

#categorical ce since we have multiple classes (10)

model.compile(loss=keras.losses.categorical_crossentropy,

optimizer=keras.optimizers.Adadelta,

metrics=['accuracy'])

训练模型并评估

在定义模型架构和编译模型之后,要使用训练集去训练模型,使得模型可以识别手写数字。这里,我们将使用 X_train 和 y_train 来拟合模型。

batch_size = 128

num_epoch = 10

#model training

model_log = model.fit(X_train, y_train,

batch_size=batch_size,

epochs=num_epoch,

verbose=1,

validation_data=(X_test, y_test))

其中,一个 epoch 表示一次全量训练样例的前向和后向传播。batch_size 就是在一次前向/后向传播过程用到的训练样例的数量。训练输出结果如下:

现在,我们来评估训练得到模型的性能。

score = model.evaluate(X_test, y_test, verbose=0)

print('Test loss:', score[0]) #Test loss: 0.0296396646054

print('Test accuracy:', score[1]) #Test accuracy: 0.9904

测试准确率达到了 99%+,这意味着这个预测模型训练的很成功。如果查看整个训练日志,就会发现随着 epoch 的次数的增多,模型在训练数据和测试数据上的损失和准确率逐渐收敛,最终趋于稳定。

import os

# plotting the metrics

fig = plt.figure

plt.subplot(2,1,1)

plt.plot(model_log.history['acc'])

plt.plot(model_log.history['val_acc'])

plt.title('model accuracy')

plt.ylabel('accuracy')

plt.xlabel('epoch')

plt.legend(['train', 'test'], loc='lower right')

plt.subplot(2,1,2)

plt.plot(model_log.history['loss'])

plt.plot(model_log.history['val_loss'])

plt.title('model loss')

plt.ylabel('loss')

plt.xlabel('epoch')

plt.legend(['train', 'test'], loc='upper right')

plt.tight_layout

fig

将模型存盘以便下次使用

现在需要将训练过的模型进行序列化。模型的架构或者结构保存在 json 文件,权重保存在 hdf5 文件。

#Save the model

# serialize model to JSON

model_digit_json = model.to_json

with open("model_digit.json", "w") as json_file:

json_file.write(model_digit_json)

# serialize weights to HDF5

model.save_weights("model_digit.h5")

print("Saved model to disk")

模型被保存后,可以被重用,也可以很方便地移植到其它环境中使用。在以后的帖子中,我们将会演示如何在生产环境中部署这个模型。

享受深度学习吧!

参考资料

Getting started with the Keras Sequential model

CS231n Convolutional Neural Networks for Visual Recognition

sambit9238/Deep-Learning

雷锋网字幕组编译

原文链接:https://towardsdatascience.com/a-simple-2d-cnn-for-mnist-digit-recognition-a998dbc1e79a

号外号外~

一个专注于

AI技术发展和AI工程师成长的求知求职社区

诞生啦!

欢迎大家访问以下链接或者扫码体验

https://club.leiphone.com/page/home

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码