百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

官方资源帖!手把手教你在TF2.0中实现CycleGAN,推特上百赞

toyiye 2024-06-21 12:21 10 浏览 0 评论

铜灵 发自 凹非寺

量子位 出品| 公众号 QbitAI

CycleGAN,一个可以将一张图像的特征迁移到另一张图像的酷算法,此前可以完成马变斑马、冬天变夏天、苹果变桔子等一颗赛艇的效果。



这行被顶会ICCV收录的研究自提出后,就为图形学等领域的技术人员所用,甚至还成为不少艺术家用来创作的工具。



也是目前大火的“换脸”技术的老前辈了。



如果你还没学会这项厉害的研究,那这次一定要抓紧上车了。

现在,TensorFlow开始手把手教你,在TensorFlow 2.0中CycleGAN实现大法。

这个官方教程贴几天内收获了满满人气,获得了Google AI工程师、哥伦比亚大学数据科学研究所Josh Gordon的推荐,推特上已近600赞。



有国外网友称赞太棒,表示很高兴看到TensorFlow 2.0教程中涵盖了最先进的模型。

这份教程全面详细,想学CycleGAN不能错过这个:

详细内容

在TensorFlow 2.0中实现CycleGAN,只要7个步骤就可以了。

1、设置输入Pipeline

安装tensorflow_examples包,用于导入生成器和鉴别器。

!pip install -q git+https://github.com/tensorflow/examples.git

!pip install -q tensorflow-gpu==2.0.0-beta1
import tensorflow as tf
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
tfds.disable_progress_bar()
AUTOTUNE = tf.data.experimental.AUTOTUNE

2、输入pipeline

在这个教程中,我们主要学习马到斑马的图像转换,如果想寻找类似的数据集,可以前往:

https://www.tensorflow.org/datasets/datasets#cycle_gan

在CycleGAN论文中也提到,将随机抖动( Jitter )和镜像应用到训练集中,这是避免过度拟合的图像增强技术。

和在Pix2Pix中的操作类似,在随机抖动中吗,图像大小被调整成286×286,然后随机裁剪为256×256。

在随机镜像中吗,图像随机水平翻转,即从左到右进行翻转。

dataset, metadata = tfds.load('cycle_gan/horse2zebra',
 with_info=True, as_supervised=True)
train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']




BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
 cropped_image = tf.image.random_crop(
 image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
 return cropped_image


# normalizing the images to [-1, 1]
def normalize(image):
 image = tf.cast(image, tf.float32)
 image = (image / 127.5) - 1
 return image


def random_jitter(image):
 # resizing to 286 x 286 x 3
 image = tf.image.resize(image, [286, 286],
 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
 # randomly cropping to 256 x 256 x 3
 image = random_crop(image)
 # random mirroring
 image = tf.image.random_flip_left_right(image)
 return image


def preprocess_image_train(image, label):
 image = random_jitter(image)
 image = normalize(image)
 return image


def preprocess_image_test(image, label):
 image = normalize(image)
 return image


train_horses = train_horses.map(
 preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
 BUFFER_SIZE).batch(1)
train_zebras = train_zebras.map(
 preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
 BUFFER_SIZE).batch(1)
test_horses = test_horses.map(
 preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
 BUFFER_SIZE).batch(1)
test_zebras = test_zebras.map(
 preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
 BUFFER_SIZE).batch(1)


sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))


plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)
plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)



plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)
plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)




3、导入并重新使用Pix2Pix模型

通过安装tensorflow_examples包,从Pix2Pix中导入生成器和鉴别器。

这个教程中使用的模型体系结构与Pix2Pix中很类似,但也有一些差异,比如Cyclegan使用的是实例规范化而不是批量规范化,比如Cyclegan论文使用的是修改后的resnet生成器等。

我们训练两个生成器(G和F)和两个鉴别器(X和Y)。生成器G架构图像X转换为图像Y,生成器F将图像Y转换为图像X。

鉴别器D_X区分图像X和生成的图像X(F(Y)),辨别器D_Y区分图像Y和生成的图像Y(G(X))。



OUTPUT_CHANNELS = 3
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)


to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8
plt.subplot(221)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)
plt.subplot(222)
plt.title('To Zebra')
plt.imshow(to_zebra[0] * 0.5 * contrast + 0.5)
plt.subplot(223)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)
plt.subplot(224)
plt.title('To Horse')
plt.imshow(to_horse[0] * 0.5 * contrast + 0.5)
plt.show()




plt.figure(figsize=(8, 8))
plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')
plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')
plt.show()



4、损失函数

在CycleGAN中,因为没有用于训练的成对数据,因此无法保证输入X和目标Y在训练期间是否有意义。因此,为了强制学习正确的映射,CycleGAN中提出了“循环一致性损失”(cycle consistency loss)。

鉴别器和生成器的损失与Pix2Pix中的类似。

LAMBDA = 10


loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real, generated):
 real_loss = loss_obj(tf.ones_like(real), real)
 generated_loss = loss_obj(tf.zeros_like(generated), generated)
 total_disc_loss = real_loss + generated_loss
 return total_disc_loss * 0.5


def generator_loss(generated):
 return loss_obj(tf.ones_like(generated), generated)


循环一致性意味着结果接近原始输入。

例如将一个句子和英语翻译成法语,再将其从法语翻译成英语后,结果与原始英文句子相同。

在循环一致性损失中,图像X通过生成器传递C产生的图像Y^,生成的图像Y^通过生成器传递F产生的图像X^,然后计算平均绝对误差X和X^。

前向循环一致性损失为:



反向循环一致性损失为:




def calc_cycle_loss(real_image, cycled_image):
 loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
 return LAMBDA * loss1

初始化所有生成器和鉴别器的的优化:

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

5、检查点

checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(generator_g=generator_g,
 generator_f=generator_f,
 discriminator_x=discriminator_x,
 discriminator_y=discriminator_y,
 generator_g_optimizer=generator_g_optimizer,
 generator_f_optimizer=generator_f_optimizer,
 discriminator_x_optimizer=discriminator_x_optimizer,
 discriminator_y_optimizer=discriminator_y_optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
 ckpt.restore(ckpt_manager.latest_checkpoint)
 print ('Latest checkpoint restored!!')

6、训练

注意:为了使本教程的训练时间合理,本示例模型迭代次数较少(40次,论文中为200次),预测效果可能不如论文准确。

EPOCHS = 40


def generate_images(model, test_input):
 prediction = model(test_input)
 plt.figure(figsize=(12, 12))
 display_list = [test_input[0], prediction[0]]
 title = ['Input Image', 'Predicted Image']
 for i in range(2):
 plt.subplot(1, 2, i+1)
 plt.title(title[i])
 # getting the pixel values between [0, 1] to plot it.
 plt.imshow(display_list[i] * 0.5 + 0.5)
 plt.axis('off')
 plt.show()


尽管训练起来很复杂,但基本的步骤只有四个,分别为:获取预测、计算损失、使用反向传播计算梯度、将梯度应用于优化程序。

@tf.function
def train_step(real_x, real_y):
 # persistent is set to True because gen_tape and disc_tape is used more than
 # once to calculate the gradients.
 with tf.GradientTape(persistent=True) as gen_tape, tf.GradientTape(
 persistent=True) as disc_tape:
 fake_y = generator_g(real_x, training=True)
 cycled_x = generator_f(fake_y, training=True)
 fake_x = generator_f(real_y, training=True)
 cycled_y = generator_g(fake_x, training=True)
 disc_real_x = discriminator_x(real_x, training=True)
 disc_real_y = discriminator_y(real_y, training=True)
 disc_fake_x = discriminator_x(fake_x, training=True)
 disc_fake_y = discriminator_y(fake_y, training=True)
 # calculate the loss
 gen_g_loss = generator_loss(disc_fake_y)
 gen_f_loss = generator_loss(disc_fake_x)
 # Total generator loss = adversarial loss + cycle loss
 total_gen_g_loss = gen_g_loss + calc_cycle_loss(real_x, cycled_x)
 total_gen_f_loss = gen_f_loss + calc_cycle_loss(real_y, cycled_y)
 disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
 disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
 # Calculate the gradients for generator and discriminator
 generator_g_gradients = gen_tape.gradient(total_gen_g_loss, 
 generator_g.trainable_variables)
 generator_f_gradients = gen_tape.gradient(total_gen_f_loss, 
 generator_f.trainable_variables)
 discriminator_x_gradients = disc_tape.gradient(
 disc_x_loss, discriminator_x.trainable_variables)
 discriminator_y_gradients = disc_tape.gradient(
 disc_y_loss, discriminator_y.trainable_variables)
 # Apply the gradients to the optimizer
 generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
 generator_g.trainable_variables))
 generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
 generator_f.trainable_variables))
 discriminator_x_optimizer.apply_gradients(
 zip(discriminator_x_gradients,
 discriminator_x.trainable_variables))
 discriminator_y_optimizer.apply_gradients(
 zip(discriminator_y_gradients,
 discriminator_y.trainable_variables))


for epoch in range(EPOCHS):
 start = time.time()
 n = 0
 for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
 train_step(image_x, image_y)
 if n % 10 == 0:
 print ('.', end='')
 n+=1
 clear_output(wait=True)
 # Using a consistent image (sample_horse) so that the progress of the model
 # is clearly visible.
 generate_images(generator_g, sample_horse)
 if (epoch + 1) % 5 == 0:
 ckpt_save_path = ckpt_manager.save()
 print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
 ckpt_save_path))
 print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
 time.time()-start))





7、使用测试集生成图像

# Run the trained model on the test dataset
for inp in test_horses.take(5):
 generate_images(generator_g, inp)






8、进阶学习方向

在上面的教程中,我们学习了如何从Pix2Pix中实现的生成器和鉴别器进一步实现CycleGAN,接下来的学习你可以尝试使用TensorFlow中的其他数据集。

你还可以用更多次的迭代改善结果,或者实现论文中修改的ResNet生成器,进行知识点的进一步巩固。

传送门

https://www.tensorflow.org/beta/tutorials/generative/cyclegan

GitHub地址:

https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cyclegan.ipynb

— 完 —

诚挚招聘

量子位正在招募编辑/记者,工作地点在北京中关村。期待有才气、有热情的同学加入我们!相关细节,请在量子位公众号(QbitAI)对话界面,回复“招聘”两个字。

量子位 QbitAI · 头条号签约作者

?'?' ? 追踪AI技术和产品新动态

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码