百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

解析生成敌对网络

toyiye 2024-06-21 12:07 9 浏览 0 评论

在本教程中,您将了解什么是生成敌对网络(GAN),而不涉及数学的细节。之后,您将学习如何编写一个可以创建数字的简单GAN!

比喻

理解GAN的最简单方法是通过一个简单的比喻:

假设有一家商店从顾客那里购买某些种类的葡萄酒,他们以后会再销售。

然而,有些恶意的顾客为了获得金钱而出售假酒。在这种情况下,店主必须能够区分假酒和正品葡萄酒。

您可以想象,最初,伪造者在尝试出售假酒时可能会犯很多错误,并且店主很容易认定该酒不是真实的。由于这些失败,伪造者会继续尝试使用不同的技术来模拟真正的葡萄酒,有些最终会成功。现在,伪造者知道某些技术已经超过了店主的检测,他可以开始进一步改进基于这些技术的假酒。

同时,店主可能会从其他店主或葡萄酒专家那里得到一些反馈,说明她拥有的一些葡萄酒不是原装的。这意味着店主必须改善她是如何确定葡萄酒是伪造的还是真实的。伪造者的目标是制造与真实葡萄酒无法区分的葡萄酒,而店主的目标是准确地分辨葡萄酒是否真实。

这种来回的竞争是GAN背后的主要思想。

生成敌对网络的组成部分

使用上面的例子,我们可以想出一个GAN的体系结构。

在GANs中有两个主要的组件:生成器和鉴别器。在这个例子中,店主被称为“鉴别器网络”,通常是一个卷积神经网络(因为GANs主要用于图像任务),它给图像分配一个真实的概率。

forger被称为生成网络,也是典型的卷积神经网络(with deconvolution layers)。该网络采用一些噪声矢量并输出图像。当训练生成网络时,它会了解图像的哪些区域可以改进/改变,这样鉴别器就很难区分生成的图像和真实的图像。

生成网络不断地生成与真实图像更接近的图像,而辨别网络则试图确定真实图像和假图像之间的区别。最终的目标是有一个生成的网络,它能产生与真实图像难以区分的图像。

生成网络不断生成更接近真实图像的图像,而辨别网络试图确定真实图像和假图像之间的差异。最终的目标是建立一个可生成与真实图像无法区分的图像的生成网络。

一个简单的Keras生成对抗网络

现在您已了解GAN是什么以及它们的主要组成部分,现在我们可以开始编写一个非常简单的代码。您将使用Keras,如果您不熟悉此Python库,则应在继续之前阅读相关教程。

您需要做的第一件事是通过以下方式安装以下软件包pip:

- keras

- matplotlib

- tensorflow

- tqdm

您将matplotlib用于绘制tensorflowKeras后端库,并且tqdm为每个次迭代显示一个奇特的进度条。

下一步是创建一个Python脚本。在这个脚本中,你首先需要导入你将要使用的所有模块和函数。在使用它们时将给出每个解释。

import os

import numpy as np

import matplotlib.pyplot as plt

from tqdm import tqdm

from keras.layers import Input

from keras.models import Model, Sequential

from keras.layers.core import Dense, Dropout

from keras.layers.advanced_activations import LeakyReLU

from keras.datasets import mnist

from keras.optimizers import Adam

from keras import initializers

设置一些变量:

# Let Keras know that we are using tensorflow as our backend engine

os.environ["KERAS_BACKEND"] = "tensorflow"

# To make sure that we can reproduce the experiment and get the same results

np.random.seed(10)

# The dimension of our random noise vector.

random_dim = 100

在开始构建鉴别器和生成器之前,您应该首先收集并预处理数据。您将使用流行的MNIST数据集,该数据集具有一组从0到9范围内的单个数字的图像。

MINST数字的例子

def load_minst_data():

# load the data

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# normalize our inputs to be in the range[-1, 1]

x_train = (x_train.astype(np.float32) - 127.5)/127.5

# convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have

# 784 columns per row

x_train = x_train.reshape(60000, 784)

return (x_train, y_train, x_test, y_test)

请注意,这mnist.load_data()是Keras的一部分,并允许您轻松将MNIST数据集导入您的工作区。

现在,你可以创建你的生成器和鉴别器网络。您将为这两个网络使用Adam优化器。对于生成器和鉴别器,您将创建一个带有三个隐藏层的神经网络,激活函数为Leaky Relu。您还应该为鉴别器添加dropout层,以提高其对未见图像的鲁棒性。

def get_optimizer():

return Adam(lr=0.0002, beta_1=0.5)

def get_generator(optimizer):

generator = Sequential()

generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))

generator.add(LeakyReLU(0.2))

generator.add(Dense(512))

generator.add(LeakyReLU(0.2))

generator.add(Dense(1024))

generator.add(LeakyReLU(0.2))

generator.add(Dense(784, activation='tanh'))

generator.compile(loss='binary_crossentropy', optimizer=optimizer)

return generator

def get_discriminator(optimizer):

discriminator = Sequential()

discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))

discriminator.add(LeakyReLU(0.2))

discriminator.add(Dropout(0.3))

discriminator.add(Dense(512))

discriminator.add(LeakyReLU(0.2))

discriminator.add(Dropout(0.3))

discriminator.add(Dense(256))

discriminator.add(LeakyReLU(0.2))

discriminator.add(Dropout(0.3))

discriminator.add(Dense(1, activation='sigmoid'))

discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)

return discriminator

终于到了将generator 和discriminator 放在一起的时候了!

def get_gan_network(discriminator, random_dim, generator, optimizer):

# We initially set trainable to False since we only want to train either the

# generator or discriminator at a time

discriminator.trainable = False

# gan input (noise) will be 100-dimensional vectors

gan_input = Input(shape=(random_dim,))

# the output of the generator (an image)

x = generator(gan_input)

# get the output of the discriminator (probability if the image is real or not)

gan_output = discriminator(x)

gan = Model(inputs=gan_input, outputs=gan_output)

gan.compile(loss='binary_crossentropy', optimizer=optimizer)

return gan

为了保持完整性,您可以创建一个功能,每20次迭代将保存您生成的图像。

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)):

noise = np.random.normal(0, 1, size=[examples, random_dim])

generated_images = generator.predict(noise)

generated_images = generated_images.reshape(examples, 28, 28)

plt.figure(figsize=figsize)

for i in range(generated_images.shape[0]):

plt.subplot(dim[0], dim[1], i+1)

plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')

plt.axis('off')

plt.tight_layout()

plt.savefig('gan_generated_image_epoch_%d.png' % epoch)

您现在已经编码了大部分网络。剩下的就是训练这个网络,并看看你创建的图像。

def train(epochs=1, batch_size=128):

# Get the training and testing data

x_train, y_train, x_test, y_test = load_minst_data()

# Split the training data into batches of size 128

batch_count = x_train.shape[0] / batch_size

# Build our GAN netowrk

adam = get_optimizer()

generator = get_generator(adam)

discriminator = get_discriminator(adam)

gan = get_gan_network(discriminator, random_dim, generator, adam)

for e in xrange(1, epochs+1):

print '-'*15, 'Epoch %d' % e, '-'*15

for _ in tqdm(xrange(batch_count)):

# Get a random set of input noise and images

noise = np.random.normal(0, 1, size=[batch_size, random_dim])

image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

# Generate fake MNIST images

generated_images = generator.predict(noise)

X = np.concatenate([image_batch, generated_images])

# Labels for generated and real data

y_dis = np.zeros(2*batch_size)

# One-sided label smoothing

y_dis[:batch_size] = 0.9

# Train discriminator

discriminator.trainable = True

discriminator.train_on_batch(X, y_dis)

# Train generator

noise = np.random.normal(0, 1, size=[batch_size, random_dim])

y_gen = np.ones(batch_size)

discriminator.trainable = False

gan.train_on_batch(noise, y_gen)

if e == 1 or e % 20 == 0:

plot_generated_images(e, generator)

if __name__ == '__main__':

train(400, 128)

在400个epochs训练之后,您可以查看生成的图像。看着第一次迭代后产生的图像,你可以看到,它没有任何真正的结构,观察40次迭代后的图像,数字开始成形,最后,在400个epochs产生的图像时代显示清晰的数字虽然仍然认不出来。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码