百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

TensorFlow实战MNIST手写数字识别的-代价函数优化准确率98%

toyiye 2024-06-21 12:37 14 浏览 0 评论

MNIST 是一个TensorFlow入门级的计算机视觉数据集,下载MNIST资源包https://download.csdn.net/download/bysjlwdx/10770046

最简单的tensorflow的手写识别模型,这一节我们将会介绍其简单的优化模型。我们会从代价函数,多层感知器,防止过拟合,以及优化器的等几个方面来介绍优化过程。

1.代价函数的优化:

我们可以这样将代价函数理解为真实值与预测值的差距,我们神经网络训练的目的就是调整W,b等参数来让这个代价函数的值最小。上一节我们用到的是二次代价函数:

在TensorFlow中的实现为:loss = tf.reduce_mean(tf.square(y-prediction)),但是这个代价函数会带来一定的问题,比如说刚开始学习的会很慢。我们知道神经网络的学习是通过梯度的反向传播来更新参数W,b的:

但是我们的sigmoid激活函数为:

当z很大的时候,例如在B点时,σ'(z)即改点切线的斜率将会很小,导致W,b的梯度很小,神经网络更新的将会很慢。为了解决这个问题,这一节我们将会引入交叉熵代价函数:

其中C为代价函数,x为样本,y为实际值,a为预测值,n为样本总数。

我们先来观察一下这个代价函数:

当实际值y=1时,C= -1/n *∑ylna, 此时当a->1时,C->0 ,当a->0时C->无穷大

当实际值y=0时,C=-1/n *∑ln(1-a) 此时当a->1时,C->无穷大 ,当a->0时C->0

可以发现当预测值a=实际值y时,这个代价函数将会最小。

接下来我们对其求梯度得:

可以发现其对于W,b的梯度是与σ'(z)无关的,不会因为Z过大引起学习过慢的问题,而且我们发现W,b的梯度与σ(z)-y有关,而这个差值就是预测值与真实值的差值,也就是说当预测值与真实值的偏差很大时,神经网络的更新会很快,当预测值与真实值的偏差很小时,神经网络的更新会减慢,这恰恰符合了我们神经网络的更新策略。因此我们将会用

交叉熵代价函数来代替二次代价函数。

代码如下:

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 载入数据集
# 当前路径
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# 每个批次的大小
# 以矩阵的形式放进去
batch_size = 100
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size
# 定义三个placeholder
# 28 x 28 = 784
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
# 学习率
lr = tf.Variable(0.001, dtype=tf.float32)
# 创建一个的神经网络
# 输入层784,隐藏层一500,隐藏层二300,输出层10个神经元
# 隐藏层
W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1))
b1 = tf.Variable(tf.zeros([500]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)
W2 = tf.Variable(tf.truncated_normal([500, 300], stddev=0.1))
b2 = tf.Variable(tf.zeros([300]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)
W3 = tf.Variable(tf.truncated_normal([300, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(L2_drop, W3) + b3)
# 交叉熵代价函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
# 训练
train_step = tf.train.AdamOptimizer(lr).minimize(loss)
# 初始化变量
init = tf.global_variables_initializer()
# 结果存放在一个布尔型列表中
# tf.argmax(y, 1)与tf.argmax(prediction, 1)相同返回True,不同则返回False
# argmax返回一维张量中最大的值所在的位置
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# 求准确率
# tf.cast(correct_prediction, tf.float32) 将布尔型转换为浮点型
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
with tf.Session() as sess:
 sess.run(init)
 # 总共51个周期
 for epoch in range(51):
 # 刚开始学习率比较大,后来慢慢变小
 sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))
 # 总共n_batch个批次
 for batch in range(n_batch):
 # 获得一个批次
 batch_xs, batch_ys = mnist.train.next_batch(batch_size)
 sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.0})
 learning_rate = sess.run(lr)
 # 训练完一个周期后测试数据准确率
 acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
 print("Iter" + str(epoch) + ", Testing Accuracy" + str(acc) + ", Learning_rate" + str(learning_rate))

运行执行结果:

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码