百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

如何从Keras中的图像文件中进行mixup training

toyiye 2024-06-23 18:32 21 浏览 0 评论

什么是mixup training?

论文mixup: BEYOND EMPIRICAL RISK MINIMIZATION(https://arxiv.org/pdf/1710.09412.pdf)提供了传统图像增强技术的替代方案,如缩放和旋转。通过两个现有实例的加权线性插值形成一个新的实例。

(xi; yi)和(xj; yj)是从训练数据中随机抽取的两个例子,λ∈[0; 1]。实际上,λ是从β分布中随机取样的,即β(α;α)。

α∈[0.1; 0.4]导致性能提高,较小的α产生较少的mixup效果,而对于较大的α,mixup会导致underfitting。

正如您在下图中所看到的,给定小的α= 0.2,β分布采样更接近0或1的值,使得mixup结果更接近于两个实例中的一个。

import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
alpha = 0.2
array = np.random.beta(alpha, alpha, 5000)
h = sorted(array) #sorted
fit = stats.norm.pdf(h, np.mean(h), np.std(h)) #this is a fitting indeed
plt.hist(h,normed=True)
plt.title('Beta distribution')
plt.show()

mixup training有哪些好处?

传统的数据增强(如Keras ImageDataGenerator类中提供的数据增强)可以持续改进泛化,但该过程依赖于机器学习数据集,因此需要使用专业知识。

此外,数据增强不会模拟不同类的实例之间的关系。

另一方面,

  • Mixup是一种与数据无关的数据增强例程。
  • 它使决策边界从一个类到另一个类线性地过渡,从而提供更平滑的不确定性估计。
  • 它减少了损坏标签的存储
  • 增强了对抗性实例的鲁棒性,稳定了生成对抗性网络的训练。

Keras中的Mixup图像数据生成器

让我们实现一个图像数据生成器,它从文件中读取图像,并使用Keras model.fit_generator()开箱即用。Python代码如下:

import numpy as np
train_dir = "./data"
batch_size = 5
validation_split = 0.3
img_height = 150
img_width = 150
epochs = 10
class MixupImageDataGenerator():
 def __init__(self, generator, directory, batch_size, img_height, img_width, alpha=0.2, subset=None):
 """Constructor for mixup image data generator.
 Arguments:
 generator {object} -- An instance of Keras ImageDataGenerator.
 directory {str} -- Image directory.
 batch_size {int} -- Batch size.
 img_height {int} -- Image height in pixels.
 img_width {int} -- Image width in pixels.
 Keyword Arguments:
 alpha {float} -- Mixup beta distribution alpha parameter. (default: {0.2})
 subset {str} -- 'training' or 'validation' if validation_split is specified in
 `generator` (ImageDataGenerator).(default: {None})
 """
 self.batch_index = 0
 self.batch_size = batch_size
 self.alpha = alpha
 # First iterator yielding tuples of (x, y)
 self.generator1 = generator.flow_from_directory(directory,
 target_size=(
 img_height, img_width),
 class_mode="categorical",
 batch_size=batch_size,
 shuffle=True,
 subset=subset)
 # Second iterator yielding tuples of (x, y)
 self.generator2 = generator.flow_from_directory(directory,
 target_size=(
 img_height, img_width),
 class_mode="categorical",
 batch_size=batch_size,
 shuffle=True,
 subset=subset)
 # Number of images across all classes in image directory.
 self.n = self.generator1.samples
 def reset_index(self):
 """Reset the generator indexes array.
 """
 self.generator1._set_index_array()
 self.generator2._set_index_array()
 def on_epoch_end(self):
 self.reset_index()
 def reset(self):
 self.batch_index = 0
 def __len__(self):
 # round up
 return (self.n + self.batch_size - 1) // self.batch_size
 def get_steps_per_epoch(self):
 """Get number of steps per epoch based on batch size and
 number of images.
 Returns:
 int -- steps per epoch.
 """
 return self.n // self.batch_size
 def __next__(self):
 """Get next batch input/output pair.
 Returns:
 tuple -- batch of input/output pair, (inputs, outputs).
 """
 if self.batch_index == 0:
 self.reset_index()
 current_index = (self.batch_index * self.batch_size) % self.n
 if self.n > current_index + self.batch_size:
 self.batch_index += 1
 else:
 self.batch_index = 0
 # random sample the lambda value from beta distribution.
 l = np.random.beta(self.alpha, self.alpha, self.batch_size)
 X_l = l.reshape(self.batch_size, 1, 1, 1)
 y_l = l.reshape(self.batch_size, 1)
 # Get a pair of inputs and outputs from two iterators.
 X1, y1 = self.generator1.next()
 X2, y2 = self.generator2.next()
 # Perform the mixup.
 X = X1 * X_l + X2 * (1 - X_l)
 y = y1 * y_l + y2 * (1 - y_l)
 return X, y
 def __iter__(self):
 while True:
 yield next(self)

mixup生成器的核心由一对迭代器组成,这些迭代器一次一个地从一个目录中随机采样图像,并在该__next__方法中执行mixup。

然后,您可以创建用于拟合机器学习模型的训练和验证生成器,注意我们不在验证生成器中使用mixup。

train_dir = "./data"
batch_size = 5
validation_split = 0.3
img_height = 150
img_width = 150
epochs = 10
# Optional additional image augmentation with ImageDataGenerator.
input_imgen = ImageDataGenerator(
 rescale=1./255,
 rotation_range=5,
 width_shift_range=0.05,
 height_shift_range=0,
 shear_range=0.05,
 zoom_range=0,
 brightness_range=(1, 1.3),
 horizontal_flip=True,
 fill_mode='nearest',
 validation_split=validation_split)
# Create training and validation generator.
train_generator = MixupImageDataGenerator(generator=input_imgen,
 directory=train_dir,
 batch_size=batch_size,
 img_height=img_height,
 img_width=img_height,
 subset='training')
validation_generator = input_imgen.flow_from_directory(train_dir,
 target_size=(
 img_height, img_width),
 class_mode="categorical",
 batch_size=batch_size,
 shuffle=True,
 subset='validation')
print('training steps: ', train_generator.get_steps_per_epoch())
print('validation steps: ', validation_generator.samples // batch_size)

像往常一样构建Keras图像分类机器学习模型。

from tensorflow.keras.applications import VGG16
conv_base = VGG16(weights='imagenet',
 include_top=False,
 input_shape=(img_height, img_width, 3))
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras import optimizers
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(4, activation='sigmoid'))
conv_base.trainable = False
model.compile(optimizer=optimizers.RMSprop(lr=2e-5),
 loss='binary_crossentropy',
 metrics=['acc'])

训练机器学习模型

train_generator.reset()
validation_generator.reset()
# Start the traning.
history = model.fit_generator(
 train_generator,
 steps_per_epoch=train_generator.get_steps_per_epoch(),
 validation_data=validation_generator,
 validation_steps=validation_generator.samples // batch_size,
 epochs=epochs)

我们可以使用以下Python代码段可视化一批Mixup 图像和标签。

sample_x, sample_y = next(train_generator)
for i in range(batch_size):
 display(image.array_to_img(sample_x[i]))
print(sample_y)

结论

您可能认为一次Mixup 超过2个实例可能会导致更好的训练,相反,将三个或更多的例子与从beta分布的多元泛化中取样的权重组合,并不能提供进一步的增益,而是增加了Mixup 的计算成本。此外,仅在具有相同标签的输入之间进行插值并不会导致性能的提高。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码