百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

TensorFlow 数据读取

toyiye 2024-06-21 12:37 9 浏览 0 评论

  • 一、使用 placeholder + feed_dict 传入数据
  • 二、使用 TFRecords 统一输入数据的格式
  • 0、TFRecords 数据格式的优缺点
  • 1、将数据转换为 .tfrecords 文件
  • a、获得图片的保存路径和标签
  • b、指定编码函数
  • c、将图片数据和标签(或其它需要需要保存的数据)都转成 TFRecods 格式
  • 2、读取并解码 .tfrecords 文件并生成 batch
  • a、指定想要读取的 .tfrecords 文件列表
  • b、创建一个输入文件名队列来维护输入文件列表
  • c、读取并解码
  • 3、将 batch 数据喂入计算图并开始训练、验证、测试等
  • 三、参考资料

一、使用 placeholder + feed_dict 传入数据

placeholder 是 Tensorflow 中的占位符必须要指定将传给该占位符的值的数据类型 dtype ,一般为 tf.float32 形式;然后通过 sess.run() 的可选参数 feed_dict 为给占位符喂入实际的数据.eg: sess.run(***, feed_dict={input: **})

input = tf.placeholder(tf.float32, shape=[2], name="my_input")

  • dtype:指定了将传给该占位符的值的数据类型。该参数是必须指定的,因为需要确保不出现类型不匹配的错误
  • shape:指定了所要传入的 Tensor 对象的形状,shape 参数的默认值为None,表示可接收任意形状的Tensor对象
  • name:与任何 op 一样,也可在 tf.placeholder 中指定一个 name 标识符
input1 = tf.placeholder(tf.float32)
input2 = tf.placeholder(tf.float32)
output = tf.add(input1, input2)
with tf.Session() as sess:
 print sess.run([output], feed_dict={input1:[7.], input2:[2.]})
>>> [array([ 9.], dtype=float32)]
1
2
3
4
5
6
7
8

Note:在 shape 的一个维度上使用 None 可以方便的使用不同 batch 的大小。在训练时,把数据分成比较小的 batch,但在测试时,可以一次使用全部的数据。但要注意,当数据集比较大时,将大量数据放入一个 batch 可能导致内存溢出


二、使用 TFRecords 统一输入数据的格式

0、TFRecords 数据格式的优缺点

  • TFRecord 文件中的数据都是通过 tf.train.Example Protocol Buffer 的格式存储的,它的优缺点如下所示:
  • 优点:
  • 可以统一不同的原始数据格式
  • 更加有效的管理不同的属性、更好的利用内存、更方便的复制和移动
  • 缺点:
  • 转换过后 tfrecords 文件会占用较大内存

1、将数据转换为 .tfrecords 文件

a、获得图片的保存路径和标签

# 获得图片的保存路径和标签,以便后面的读取和转换
def get_file(file_dir):
 '''Get full image directory and corresponding labels
 Args:
 file_dir: file directory
 Returns:
 images: image directories, list, string
 labels: label, list, int
 '''
1
2
3
4
5
6
7
8
9

b、指定编码函数

tf.train.Example的数据结构中包含了一个从属性到取值的字典。

  • 属性名称(feature name)为一个字符串
  • 属性的取值(feature value)可以为字符串列表(BytesList)、实数列表(FloatList)或者整数列表(Int64List),通过以下函数编码为Example proto形式的返回值
# Wrapper for inserting int64 features into Example proto
def _int64_feature(value):
 return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# Wrapper for inserting bytes features into Example proto
def _bytes_feature(value):
 return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
1
2
3
4
5
6
7

c、将图片数据和标签(或其它需要需要保存的数据)都转成 TFRecods 格式

  • 指定转换数据格式后的保存路径和文件名称
  • 创建一个实例对象 writer,用于后面序列化数据的写入
  • 将所有数据按照 tf.train.Example Protocol Buffer 的格式存储
  • 取得图片的样本总数
  • 循环读取图片和标签的内容:将图片内容转换为字符串型,当有多个标签时,应将多标签内容也转换为字符串型
  • 使用编码函数将一个样例的所有数据(图片和标签内容等)转换为Example Protocol Buffer
  • 调用实例对象 writer 的 write 方法将序列化后的 Example Protocol Buffer 写入 TFRecords 文件
  • 所有样本数据都转换完毕时,调用实例对象 writer 的 close 方法结束写入过程
import tensorflow as tf
import numpy as np
import os
import skimage.io as io
# 将图片数据和标签(或者其它需要需要保存的数据)都转成 TFRecods 格式的数据
def convert_to_tfrecord(images, labels, save_dir, name):
 '''convert all images and labels to one tfrecord file.
 Args:
 images: list of image directories, string type
 labels: list of labels, int type
 save_dir: the directory to save tfrecord file, e.g.: '/home/folder1/'
 name: the name of tfrecord file, string type, e.g.: 'train'
 Return:
 no return
 '''
 # 指定数据转换格式后的保存路径和名称
 filename = os.path.join(save_dir, name + '.tfrecords')
 # 创建一个实例对象 writer,用于后面序列化数据的写入
 writer = tf.python_io.TFRecordWriter(filename)
 # 取得图片的样本总数
 n_samples = len(labels)
 print('\nTransform start......')
 # 将所有数据(包括标签等)按照 tf.train.Example Protocol Buffer 的格式存储
 for i in np.arange(n_samples):
 try:
 image = io.imread(images[i]) # read a image, returned image type must be array!
 image_raw = image.tostring() # 将图片矩阵转化为字符串,tobytes同理
 label = int(labels[i]) # 当单个label为字符串时,需要将其转换为int型
 # 创建tf.train.Example 协议内存块,把标签、图片数据作为特定字段存入(数据类型转换)
 example = tf.train.Example(features=tf.train.Features(feature={
 'label': _int64_feature(label),
 'image_raw': _bytes_feature(image_raw)}))
 # 调用实例对象 writer 的 write 方法将序列化后的 example 协议内存块写入 TFRecord 文件 
 writer.write(example.SerializeToString())
 # 跳过不能读取的图片 
 except IOError as e:
 print('Could not read:', images[i])
 print('error: %s' % e)
 print('Skip it!\n')
 # 调用实例对象 writer 的 close 方法结束写入过程
 writer.close()
 print('Transform done!')
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

2、读取并解码 .tfrecords 文件并生成 batch

A typical pipeline for reading records from files has the following stages:

  • The list of filenames
  • Filename queue
  • Optional filename shuffling
  • Optional epoch limit
  • A Reader for the file format
  • A decoder for a record read by the reader
  • Optional preprocessing
  • Example queue

a、指定想要读取的 .tfrecords 文件列表

# 直接指定文件列表
filenames = ['/path/to/train_dataset1.tfrecords', '/path/to/train_dataset2.tfrecords']
# 通过 tf.train.match_filenames_once 函数获取文件列表
filenames = tf.train.match_filenames_once(os.path.join(FLAGS.data_dir, 'train_*.tfrecords'))
# 通过 python 中的 glob 模块获取文件列表
filenames = glob.glob(os.path.join(FLAGS.data_dir, 'train_*.tfrecords'))
1
2
3
4
5
6
7
8

b、创建一个输入文件名队列来维护输入文件列表

  • 通过tf.train.string_input_producer(filenames, shuffle=True, num_epochs=None)函数来产生输入文件名队列
  • 可参考 十图详解tensorflow数据读取机制 进行理解,如下图所示,当系统检测到了“结束”,就会自动抛出一个异常(OutOfRange)外部捕捉到这个异常后就可以结束程序了,不过个人理解这里A、B、C 应该为.tfrecords格式的文件,即类似上面filenames中的内容


tf.train.string_input_producer(
 string_tensor,
 num_epochs=None,
 shuffle=True,
 seed=None,
 capacity=32,
 shared_name=None,
 name=None,
 cancel_op=None
)
# 参数
string_tensor: A 1-D string tensor with the strings to produce, 如上面的filenames
num_epochs: An integer (optional). If specified, string_input_producer produces each string from string_tensor num_epochs times before generating an OutOfRange error. If not specified, string_input_producer can cycle through the strings in string_tensor an unlimited number of times.
shuffle: Boolean. If true, the strings are randomly shuffled within each epoch.
capacity: An integer. Sets the queue capacity.
# 返回值
A queue with the output strings. A QueueRunner for the Queue is added to the current Graph's QUEUE_RUNNER collection.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

c、读取并解码

  • 创建一个实例对象 reader,用于读取 .tfrecords中的样例
  • 调用实例对象 reader 的 read 方法,读取文件名队列中的一个样例,得到文件名和序列化的 Example Protocol Buffer
  • 按照字段格式,使用tf.parse_single_example() 解码器对上述序列化的 Example Protocol Buffer的一个样例进行解码,返回一个 dict(mapping feature keys to Tensor and SparseTensor values)
  • 通过tf.decode_raw()函数将字符串解析成图像对应的像素数组、tf.cast()函数转换标签的数据类型
  • 图像预处理
  • 构造批处理器tf.train.shuffle_batch,来产生一个批次的数据,用于神经网络的输入
def read_and_decode(filenames, batch_size, num_epochs=None):
 '''read and decode tfrecord file, generate (image, label) batches
 Args:
 filenames: the directory of tfrecord filenames, list
 batch_size: number of images in each batch
 num_epochs: None, cycle through the strings in string_tensor an unlimited number of times
 Returns:
 image: 4D tensor - [batch_size, width, height, channel]
 label: 1D tensor - [batch_size]
 '''
 # Creates a FIFO queue for holding the filenames until the reader needs them
 filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
 # 创建一个实例对象 reader, 用于读取 TFRecord 中的样例
 reader = tf.TFRecordReader()
 # 调用实例对象 reader 的 read 方法,读取文件名队列中的一个样例,得到文件名和序列化的协议内存块
 _, serialized_example = reader.read(filename_queue)
 # 按照字段格式,解析读入的一个样例(序列化的协议内存块)
 img_features = tf.parse_single_example(
 serialized_example,
 features={
 'label': tf.FixedLenFeature([], tf.int64),
 'image_raw': tf.FixedLenFeature([], tf.string),
 })
 # 将字符串解析成图像对应的像素数组 Tensor("DecodeRaw:0", shape=(?,), dtype=uint8) 
 # 注意:转成字符串之前是什么类型的数据,那么这里的参数就要填成对应的类型,否则会报错
 image = tf.decode_raw(img_features['image_raw'], tf.uint8) 
 # Tensor("Cast:0", shape=(), dtype=int32) 
 label = tf.cast(img_features['label'], tf.int32) 
 ################***** Preprocessing *****####################
 # 图像预处理(resize, reshape, crop, flip, distortion, per_image_standardization ......)
 image.set_shape([FLAGS.height, FLAGS.width, FLAGS.depth]) # 将图片内容转换成多维数组形式
 image = tf.image.resize_images(image, [48, 160]) # 统一图片的尺寸
 ...
 ...
 ...
 ############***** 构造批处理器,来产生一个批次的数据 *****##############
 # num_threads:可以指定多个线程同时执行入队操作(数据读取和预处理),通过队列实现多线程处理机制
 # capacity: 队列中最多可以存储的样例个数
 # min_after_dequeue:限制了出队时队列中元素的最少个数,从而保证随机打乱顺序的作用
 image_batch, label_batch = tf.train.shuffle_batch([image, label],
 batch_size=batch_size,
 num_threads=16,
 capacity=min_queue_examples + 3 * batch_size,
 min_after_dequeue = min_queue_examples)
 return image_batch, tf.reshape(label_batch, [batch_size])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

3、将 batch 数据喂入计算图并开始训练、验证、测试等

filenames = tf.train.match_filenames_once(os.path.join(FLAGS.data_dir, 'train_*.tfrecords'))
image_batch, label_batch = read_and_decode(filenames, batch_size=BATCH_SIZE)
# tf.train.string_input_producer() 定义了一个局部变量 num_epochs,所以使用前要对其初始化
init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
 sess.run(init)
 # 声明一个 tf.train.Coordinator() 对象来协同多个线程的工作
 coord = tf.train.Coordinator() 
 # 使用 tf.train.start_queue_runners() 之后,才会开始填充队列 
 threads = tf.train.start_queue_runners(sess=sess, coord=coord) 
 try:
 # 运行 FLAGS.iteration 个 batch
 for itr in range(FLAGS.iteration): 
 # just plot one batch size
 image, label = sess.run([image_batch, label_batch])
 plot_images(image, label)
 except tf.errors.OutOfRangeError:
 print('Done training -- epoch limit reached')
 finally:
 coord.request_stop() # 通知其它线程退出,同时 corrd.should_stop()被设置成 True
 # 等待所有的线程退出 
 coord.join(threads)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码