VAE的前向过程及核心代码
VAE的前向过程及核心代码
ELBO 目标函数:
代入建模:
对 求导的思路
对
对
求导的思路
第一项(重建损失):
由于期望的分布本身依赖于参数
- “变换之前”:我们从一个由
参数化的分布中直接采样潜在变量 ,即 ,这个过程不可导 - “变换之后”:我们引入
,通常从标准正态分布中采样,即 。然后,将 转换为潜在变量 : - 求导问题:通过重参数化,原来的期望
就变成了对 的期望 - 可以将梯度符号
“穿过” 期望符号 ,
第二项(KL 散度):
这一项可以进行显式计算,即当
计算思路
最后计算的式子为:
应用蒙特卡洛(Monte Carlo, MC)方法:
- 从标准正态分布
中,采样出若干个随机变量 - 利用采样得到的
,通过重参数化公式 计算得到潜在变量 - 计算近似 ELBO:
- 通过反向传播对
和 求导,并使用梯度下降法来更新编码器和解码器的参数
前向过程
输入数据 (Input Data)
- 将一个原始数据样本
输入到模型中
- 将一个原始数据样本
编码器(Encoder)
- 编码器网络将
作为输入 - 输出:潜在分布(Latent Distribution)的参数
和
- 编码器网络将
重参数化采样(Reparameterization Trick)
- 从一个标准正态分布
中采样一个随机变量 - 计算得到潜在变量
:
- 从一个标准正态分布
解码器(Decoder)
- 将采样得到的潜在变量
作为输入 - 输出:重建后的数据
- 即
- 将采样得到的潜在变量
计算损失(Calculate Loss)
a. 重建损失(Reconstruction Loss):
- 衡量原始输入
和重建输出 之间的差异 - 通常使用二元交叉熵(BCE)或均方误差(MSE)
- 对应于 ELBO 的第一项:
- 衡量原始输入
b. KL 散度损失(KL Divergence Loss):
- 衡量编码器输出的潜在分布
与我们预设的先验分布 之间的差异 - 使用解析公式直接计算,无需采样
- 对应于 ELBO 的第二项:
- 衡量编码器输出的潜在分布
最终损失(Total Loss)
- 将重建损失和 KL 散度损失相加,得到用于优化的最终损失值
代码
1 | class VAE(nn.Module): |
All articles on this blog are licensed under CC BY-NC-SA 4.0 unless otherwise stated.








