百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

条件变分自编码(CVAE)(条件变量是什么意思)

toyiye 2024-07-04 09:15 32 浏览 0 评论

一 条件变分自编码(CVAE)

变分自编码存在一个问题,虽然可以生成一个样本,但是只能输出与输入图片相同类别的样本。虽然也可以随机从符合模型生成的高斯分布中取数据来还原成样本,但是这样的话饿哦们并不知道生成的样本属于哪个类别。条件变分编码则可以解决这个问题,让网络按指定的类别生成样本。

在变分自编码的基础上,再取理解条件编码自编码会很容易。主要的改动是,在训练测试时加入一个one-hot向量,用于表示标签向量。其实就是给编码自编码网络加入一个条件,让网络学习图片时加入标签因素,这样就可以按照指定的标签生成图片。

二 CVAE实例

在编码节点需要在输入端添加标签对应的特征,在解码阶段同样也需要将标签加入输入,这样,再解码的结果向原始的输入样本不断逼近,最终得到的模型会把输入的标签的特征当成MNIST数据的一部分,从而实现通过标签生成指定的图片。

该程序在上一节程序上作了一些修改,代码如下:

'''
条件变分自编码
'''
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST-data',one_hot=True)
print(type(mnist)) #<class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>
print('Training data shape:',mnist.train.images.shape) #Training data shape: (55000, 784)
print('Test data shape:',mnist.test.images.shape) #Test data shape: (10000, 784)
print('Validation data shape:',mnist.validation.images.shape) #Validation data shape: (5000, 784)
print('Training label shape:',mnist.train.labels.shape) #Training label shape: (55000, 10)
train_X = mnist.train.images
train_Y = mnist.train.labels
test_X = mnist.test.images
test_Y = mnist.test.labels
'''
定义网络参数
'''
n_input = 784
n_hidden_1 = 256
n_hidden_2 = 2
n_classes = 10
learning_rate = 0.001
training_epochs = 20 #迭代轮数
batch_size = 128 #小批量数量大小
display_epoch = 3
show_num = 10
x = tf.placeholder(dtype=tf.float32,shape=[None,n_input])
y = tf.placeholder(dtype=tf.float32,shape=[None,n_classes])
#后面通过它输入分布数据,用来生成模拟样本数据
zinput = tf.placeholder(dtype=tf.float32,shape=[None,n_hidden_2])
'''
定义学习参数
'''
weights = {
 'w1':tf.Variable(tf.truncated_normal([n_input,n_hidden_1],stddev = 0.001)),
 'w_lab1':tf.Variable(tf.truncated_normal([n_classes,n_hidden_1],stddev = 0.001)),
 'mean_w1':tf.Variable(tf.truncated_normal([n_hidden_1*2,n_hidden_2],stddev = 0.001)),
 'log_sigma_w1':tf.Variable(tf.truncated_normal([n_hidden_1*2,n_hidden_2],stddev = 0.001)),
 'w2':tf.Variable(tf.truncated_normal([n_hidden_2+n_classes,n_hidden_1],stddev = 0.001)),
 'w3':tf.Variable(tf.truncated_normal([n_hidden_1,n_input],stddev = 0.001))
 }
biases = {
 'b1':tf.Variable(tf.zeros([n_hidden_1])),
 'b_lab1':tf.Variable(tf.zeros([n_hidden_1])),
 'mean_b1':tf.Variable(tf.zeros([n_hidden_2])),
 'log_sigma_b1':tf.Variable(tf.zeros([n_hidden_2])),
 'b2':tf.Variable(tf.zeros([n_hidden_1])),
 'b3':tf.Variable(tf.zeros([n_input]))
 }
'''
定义网络结构
'''
#第一个全连接层是由784个维度的输入样->256个维度的输出
h1 = tf.nn.relu(tf.add(tf.matmul(x,weights['w1']),biases['b1']))
#输入标签
h_lab1 = tf.nn.relu(tf.add(tf.matmul(y,weights['w_lab1']),biases['b_lab1']))
#合并
hall1 = tf.concat([h1,h_lab1],1)
#第二个全连接层并列了两个输出网络
z_mean = tf.add(tf.matmul(hall1,weights['mean_w1']),biases['mean_b1'])
z_log_sigma_sq = tf.add(tf.matmul(hall1,weights['log_sigma_w1']),biases['log_sigma_b1'])
#然后将两个输出通过一个公式的计算,输入到以一个2节点为开始的解码部分 高斯分布样本
eps = tf.random_normal(tf.stack([tf.shape(h1)[0],n_hidden_2]),0,1,dtype=tf.float32)
z = tf.add(z_mean,tf.multiply(tf.sqrt(tf.exp(z_log_sigma_sq)),eps))
#合并
zall = tf.concat([z,y],1) #None x 12
#解码器 由12个维度的输入->256个维度的输出
h2 = tf.nn.relu(tf.matmul(zall,weights['w2']) + biases['b2'])
#解码器 由256个维度的输入->784个维度的输出 即还原成原始输入数据
reconstruction = tf.matmul(h2,weights['w3']) + biases['b3']
#这两个节点不属于训练中的结构,是为了生成指定数据时用的
zinputall = tf.concat([zinput,y],1)
h2out = tf.nn.relu(tf.matmul(zinputall,weights['w2']) + biases['b2'])
reconstructionout = tf.matmul(h2out,weights['w3']) + biases['b3']
'''
构建模型的反向传播
'''
#计算重建loss
#计算原始数据和重构数据之间的损失,这里除了使用平方差代价函数,也可以使用交叉熵代价函数 
reconstr_loss = 0.5*tf.reduce_sum((reconstruction-x)**2)
print(reconstr_loss.shape) #(,) 标量
#使用KL离散度的公式
latent_loss = -0.5*tf.reduce_sum(1 + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq),1)
print(latent_loss.shape) #(128,)
cost = tf.reduce_mean(reconstr_loss+latent_loss)
#定义优化器 
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
num_batch = int(np.ceil(mnist.train.num_examples / batch_size))
'''
开始训练
'''
with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 
 print('开始训练')
 for epoch in range(training_epochs):
 total_cost = 0.0
 for i in range(num_batch):
 batch_x,batch_y = mnist.train.next_batch(batch_size) 
 _,loss = sess.run([optimizer,cost],feed_dict={x:batch_x,y:batch_y})
 total_cost += loss
 
 #打印信息
 if epoch % display_epoch == 0:
 print('Epoch {}/{} average cost {:.9f}'.format(epoch+1,training_epochs,total_cost/num_batch))
 
 print('训练完成')
 
 #测试
 print('Result:',cost.eval({x:mnist.test.images,y:mnist.test.labels}))
 #数据可视化 根据原始图片生成自编码数据 
 reconstruction = sess.run(reconstruction,feed_dict = {x:mnist.test.images[:show_num],y:mnist.test.labels[:show_num]})
 plt.figure(figsize=(1.0*show_num,1*2)) 
 for i in range(show_num):
 #原始图像
 plt.subplot(2,show_num,i+1) 
 plt.imshow(np.reshape(mnist.test.images[i],(28,28)),cmap='gray') 
 plt.axis('off')
 
 #变分自编码器重构图像
 plt.subplot(2,show_num,i+show_num+1)
 plt.imshow(np.reshape(reconstruction[i],(28,28)),cmap='gray') 
 plt.axis('off')
 plt.show()
 
 
 '''
 高斯分布取样,根据标签生成模拟数据
 ''' 
 z_sample = np.random.randn(show_num,2)
 reconstruction = sess.run(reconstructionout,feed_dict={zinput:z_sample,y:mnist.test.labels[:show_num]}) 
 plt.figure(figsize=(1.0*show_num,1*2)) 
 for i in range(show_num):
 #原始图像
 plt.subplot(2,show_num,i+1) 
 plt.imshow(np.reshape(mnist.test.images[i],(28,28)),cmap='gray') 
 plt.axis('off')
 
 #根据标签成成模拟数据
 plt.subplot(2,show_num,i+show_num+1)
 plt.imshow(np.reshape(reconstruction[i],(28,28)),cmap='gray') 
 plt.axis('off')
 plt.show()
 

上面第一幅图是根据原始图片生成的自编码数据,第一行为原始数据,第二行为自编码数据,该数据仍然保留一些原始图片的特征。

第二幅图片是利用样本数据的标签和高斯分布之z_sample一起生成的模拟数据,我们可以看到通过标签生成的数据,已经彻底学会了样本数据的分布,并生成了与输入截然不同但带有相同意义的数据。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码