Pytorch實作系列 — VAE

mz bai
Dec 1, 2020

--

變分自編碼機是由Kingma et al.(2013, 阿姆斯特丹大學)在Auto-Encoding Variational Bayes提出,以貝氏推論為網路設計,達到變分推論的效果,可作為一種生成模型的架構。

資料集

MNIST是一個手寫數字辨識集。

網路

概念上是取先取得圖片的潛在特徵,作再參數化,經由常態分佈抽樣,解碼後生成新圖像。

損失函數

採用 mean square error 作為重建損失,加上Kullback-liebler 散度。前者使樣本更接近原本的圖片,後者為生成樣本增加噪音。

訓練

如同分類任務。

評估

生成的圖像不是很清晰,但有各種數字。

筆記

  1. CPU的訓練速度和GPU差不多。
  2. 以全連接層實現的模型大小可以少於1MB。
  3. 使用BCE loss的效果較為黑白分明,MSE loss 則有灰色部分。
  4. 增加KLD權重,會使圖片更加模糊。

結語

本文介紹如何做出簡單的variational encoder。實作代碼可參考以下連結。

參考

https://shenxiaohai.me/2018/10/20/pytorch-tutorial-advanced-02/

--

--

mz bai
mz bai

Written by mz bai

Present data engineer, former data analyst, Kaggle player, loves data modeling.