《智能系统与技术丛书 生成对抗网络入门指南》—3.1.3变分自动编码器
3.1.3 变分自动编码器
相比于普通的自动编码器,变分自动编码器(VAE)才算得上是真正的生成模型。
为了解决前文中叙述的自动编码器存在的不能通过新编码生成数据的问题,VAE在普通的自动编码器上加入了一些限制,要求产生的隐含向量能够遵循高斯分布,这个限制帮助自动编码器真正读懂训练数据的潜在规律,让自动编码器能够学习到输入数据的隐含变量模型。如果说普通自动编码器通过训练数据学习到的是某个确定的函数的话,那么VAE希望能够基于训练数据学习到参数的概率分布。
我们可以通过图3-5看一下VAE的具体实现方法,在编码阶段我们将编码器输出的结果从一个变成两个,两个向量分别对应均值向量和标准差向量。通过均值向量和标准差向量我们可以形成一个隐含变量模型,而隐含编码向量正是通过对于这个概率模型随机采样获得的。最终我们通过解码器将采样获得的隐含编码向量还原成原始图片。
图3-5 VAE实现方法
在实际的训练过程中,我们需要权衡两个问题,第一个是网络整体的准确程度,第二个是隐含变量是否可以很好地吻合高斯分布。对应这两个问题也就形成了两个损失函数:第一个是描述网络还原程度的损失函数,具体的方法是输出数据与输入数据之间的均方距离;第二个是隐含变量与高斯分布相近程度的损失函数。
在这里我们需要介绍一个概念,叫作KL散度(Kullback–Leibler divergence),也可以称作相对熵。KL散度的理论意义在于度量两个概率分布之间的差异程度,当KL散度越高的时候,说明两者的差异程度越大;而当KL散度低的时候,则说明两者的差异程度小。如果两者相同的话,则该KL散度应该为0。这里我们正是采用KL散度来计算隐含变量与高斯分布的接近程度的。
下面的公式代码将两个损失函数相加,由VAE网络在训练过程中决定如何调节这两个损失函数,通过优化这个整体损失函数来使得模型能够达到最优的结果。
generation_loss = mean(square(generated_image - real_image))(3-1)
latent_loss = KL-Divergence(latent_variable, unit_gaussian)(3-2)
loss = generation_loss + latent_loss(3-3)
在使用了VAE以后,生成数据就显得非常简单,我们只需要从高斯分布中随机采样一个隐含编码向量,然后将其输入解码器后即可生成全新的数据。如果将手写数字数据集编码成二维数据,我们可以尝试将二维数据能够生成的数据在平面上展现出来,如图3-6所示是从二维数据(-15, -15)到(15, 15)之间数据点生成的数据,可以看到随着隐含编码的变化,手写数字也会逐渐从左下角的手写数字0逐渐演变成右上角的手写数字1。
图3-6 隐含编码与对应生成之间的关系
当然VAE也存在缺陷,VAE的缺点在于训练过程中最终模型的目的是为了使得输出数据与输入数据的均方误差最小化,这使得VAE其实本质上并非学会了如何生成数据,而是更倾向于生成与真实数据更为接近的数据,甚至于为了数据越接近越好,模型基本会复制真实数据。
为了解决VAE的缺点,也为了让生成模型更加优秀,就让我们请出本书的主角—生成对抗网络(GAN)。让我们来看一下GAN究竟是什么,它是通过什么样的方法来实现生成模型的建立的。
- 点赞
- 收藏
- 关注作者
评论(0)