变分自编码器(VAE)学习

2019/02/23 Machine Learning

一直希望有时间总结一下VAE这个模型,细节还是挺多的,总是忘,记录一遍可以加深记忆,岁数大了就要想各种各样的办法帮助自己提高记忆力啊。不求特别严谨,但求合理易懂。

AutoEncoder(AE)

探究VAE之前,我们得先知道AE是什么。和PCA一样,AE发明的目的也是用于降维。他们的区别是,PCA是线性降维,而AE由于在神经网络中堆叠了非线性的激活函数因此可以做到非线性的降维。下图就是一个简单的AE的结构,只要满足输入输出维度相同,隐变量维度小于观测变量维度,这样的神经网络就可以称为一个AE。

AE的的训练就是去最小化原数据和生成数据的差异,最终得到的模型即可用来对相似的数据进行压缩(降维),得到一系列数据的特征(隐变量序列)。但是AE除了用来压缩(对应encoder),还可以用来生成新数据(对应decoder)。单拿出来decoder,为每个特征赋上不同的值,就可以还原出一个在这些特征上有差异的原数据。例如下图,原图片进行压缩后得到笑容,性别,头发等特征。为这些隐变量赋予不同的数值,既可以生成一张与原图较为相似的图片。但因为降维过程是有损的,所以是无法完全还原出原图的。

但是这里有一个问题,这个值不是随意给的。因为经过了非线性的变化,每个隐变量的分布都是未知的,只有按照分布取值才能得到有意义的结果,隐变量也不太可能是线性的,比如某一个隐变量用来描述头发长度,短头发映射的结果是3,长头发映射的结果是9,但如果输入6得到却不一定是中等长度的头发。下面左图是一张数字的图片encode到2维隐变量的结果,二维的分布是可以可视化画出来的。当在红框部分抽样时,decode出的图片是右图,会发现左上角的图片就看起来没什么意义,因为并没有原始图片降维后映射到那个位置,所以生成的图片较差。如果选择左图右下空白的部分抽样生成图片,生成的图片一定没什么意义。所以要是想根据AE来生成新数据,需要了解 $p(z)$ 的分布,在高维的情况下是很难做到的。VAE就是来解决这个问题的。

VAE的思路

VAE的思路就是,既然隐变量分布未知,那我就强迫它符合高斯分布(为什么选高斯,见Q&A)。这样一来,模型就变为了一个概率模型,即考虑了隐变量的分布。此时我们想生成一个新数据时,只需要给decoder一个标准正态分布的采样,就可以得到我们想要的图片,而不需要给他一张原始图片先编码了。

那么怎么才能让隐变量符合高斯分布呢?VAE的思路是,既然无法限制分布,那么就先生成分布,再从分布中采样作为隐变量。对于高斯分布,encoder的输出就是一个mean和一个std。对于k维的隐变量,encoder会给每一维都训练一个标准高斯分布,即输出k对mean和std,所以其实隐变量向量z符合的是一个协方差矩阵为对角阵的多维高斯分布。

需要注意的地方是,因为训练时需要反向传播来求解参数,而采样的过程并不能求导,所以VAE用到了一个叫 Reparameterization Trick 的技巧来使这个过程可以求导。具体实现方法是,把采样的过程隔离出去,先从标准正态分布采样一个 $\varepsilon$ 然后乘以 $\sigma$ 再加上 $\mu$ 就相当于在原分布采样了。另外因为encoder的输出值域是负无穷到正无穷的,为了使 $\sigma$ 为正数,这里乘以的其实是 $exp(\sigma)$,神经网络会自适应的输出合理的 $\mu$ 和 $\sigma$ (其实是 $\log \sigma$ )。

现在我们已经有了VAE的结构了,接下来就让它像AE那样去直接根据重构误差训练网络可以吗?答案是不行的。因为神经网络和人一样都是有惰性的,如果没有限制,网络不会把自己搞的很复杂,训练时把 $\sigma$ 全置为0就行了,分布变成了确定的数值,VAE等同于退化为了AE。所以除了重构误差尽可能小以外,VAE还要保证z的分布尽可能接近标准正态分布,在实现上会用KL散度来约束。这样做的结果就是,为了使重构误差足够小,模型尽可能使z的取值接近某一个最佳值,提高了模型的准确率;而为了使z接近标准正态分布,模型每次又取不到最佳值,相当于对decoder的输入加入了噪声,提高了模型的泛化能力。两种策略互相制约,最终达到一个trade-off,使模型拥有了按需生成数据的能力。

关于decoder的输出,也就是 $p(x \vert z)$ ,其实我们也可以把它假设为高斯分布(实际也是这么做的),那么使其似然最大就是都取高斯分布的均值作为输出值,理论上和encoder输入的原数据的重构误差就是最小的。所以decoder也可以训练出一个 $\mu$ 和 $\sigma$ ,进而给出X的置信区间去做区间估计。

VAE的理论推导

上述思路只是直观上的想法,仍需要严谨的数学推导去证明其可行性以及明确优化目标。其实说不好VAE的产生到底是由理论推动的还是由直观思路推动的,其实是相辅相成的,总之VAE是一个从各个角度解释都很漂亮的模型。

我们的目标是根据x的某些特征z生成一个x。x是可观测的变量,而z是隐变量,所以我们需要反过来根据x去找出z的分布,也就是去求 $p(z \vert x)$ 。根据贝叶斯估计有

显然在高维情况下这个积分没有解析形式,$p(x)$ 是求不出来的。而且如果能把 $p(x)$ 求出来,也就不需要什么隐变量了,分布直接就是最完美的模型了。有两种方法解决这个问题,一个是用MCMC去模拟估计出这个积分,另一个就是变分推断。

VAE就是采用了变分推断的思想去估计 $p(z \vert x)$ 的。假设存在一个可以用参数描述的分布 $q(z \vert x)$ ,使其与 $p(z \vert x)$ 非常相似,我们就可以用它来代替 $p(z \vert x)$ 这个分布。这种相似是用KL散度来衡量的,所以我们的目标变成了最小化 $KL(q(z \vert x) \parallel p(z \vert x))$ 。下面推导一下这个KL散度等于什么。

最终得到

因为x是给定的观测值,所以这个等式的左边是一个常数。我们的目标是最小化右边第一项的KL散度,所以就变成了要最大化右边的第二项。我们称这项为变分下界,一般的变分推断的目标都是去最大化这个变分下界。通常情况下,这个变分下界都是很难去最大化的,而VAE通过巧妙的变化,把它变得易于理解。那么继续来看一下这个变分下界是什么。

最终我们的目标变成了最大化下面的式子(变分下界)。

先看第一项,最大化变分下界需要最大化第一项,而最大化第一项其实就是最小化重建误差。为什么呢?设 $\hat x$ 是 $z$ 通过decoder后重建得到的新数据,因为decoder是一个神经网络,$z$ 和 $\hat x$ 有着对应关系,所以 $p(x \vert z)$ 相当于 $p(x \vert \hat x)$ 。又因为是关于 $q(z \vert x)$ 的期望,我们假设了z的分布为高斯分布,所以带入高斯分布解析式后最大化的东西就变成了 $e^{-\vert x - \hat x \vert^2}$ ,加上log就是最大化 $-\vert x - \hat x \vert^2$ ,也就是最小化 $\vert x - \hat x \vert^2$ ,所以第一项可以写成最小化重建误差 $\mathcal{L}(x, \hat x)$。

再看第二项,最大化变分下界需要最小化第二项。这个KL散度的意义也很明显,就是让我们假设的分布q近似于z的真实分布,与我们最初的目标相符合。由于我们假设z每个维度的分布都是独立的标准正态分布,所以也可以分开写成 $\sum_j{}KL(q_j(z \vert x) \parallel p(z))$ 。

最终我们得到的最小化目标,也就是损失函数为

我们知道 $x$ 是原数据, $z$ 是中间的隐变量,$\hat x$是重构数据,那么一个常规的想法就是用神经网络去拟合他们之间的映射。于是我们用encoder来拟合 $q(z \vert x)$ ,用decoder来拟合 $p(x \vert z)$,隐变量z正好作为中间层。如下图所示。

这与上一小节的思路完全一致,也从理论上印证了这个方法的可行性。

可视化的结果

为了更好地理解VAE,本文和AE一样举一个二维隐变量可视化的例子。当只考虑重建误差时,VAE退化为AE,虽然类别区分了出来,但隐变量分布不可知。当只考虑z的分布时,隐变量完全丢失了原数据有意义的信息。当两项同时考虑时,隐变量的分布趋近于标准正态分布,又很明显的区分出了不同数据的类别。

下图是MINIST手写数字数据集使用VAE训练后生成的结果。可以看出来VAE学习到了手写数字图片的构造,并能在图中找到每种数字之间明显的变化趋势。通过改变隐变量完全可以按需生成任意想要的数字图片。

Q&A

为什么要使 $p(z)$ 符合高斯分布,其他分布不行吗?

其实这里用任何分布都是可以的,但是高斯分布的好处是:

  1. 大部分特征都符合正态分布,且正态分布可以拟合任意分布
  2. 参数有两个,且有实际意义
  3. 均值点即为原始数据的最佳选择

参考

https://www.jeremyjordan.me/variational-autoencoders/

https://zhuanlan.zhihu.com/p/55557709

https://www.cnblogs.com/fxjwind/p/9099931.html

https://zhuanlan.zhihu.com/p/27549418

https://www.youtube.com/watch?v=uaaqyVS9-rM&feature=youtu.be

https://www.jianshu.com/p/3085936e8bc3

Search

    Table of Contents