Pytorch實作系列 — GRU

mz bai
4 min readSep 8, 2024

--

Cho et al. 在 (2014, 蒙特羅)提出 Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation,該網路是著名的RNN變形之一,因其簡化的運算和與LSTM相當的實驗表現,成為序列建模中不錯的選擇。

對LSTM不熟的可先參考

網路概念

GRU如同LSTM是一種RNN的變形,當RNN隨著時間步增加會出現訊息傳遞不穩定的問題,LSTM提出以內部狀態調控訊息傳遞,而GRU提出較為簡化的訊息調控方式,沒有output gate(o)控制輸出,forget gate(f)和input gate(i)變為reset gate(r)和update gate(z)控制新舊記憶的混和。

LSTM and GRU comparison in paper[1412.3555]

網路結構

GRU有著類似Highway network的閘道設計,軟加總當前記憶與過去記憶作為調控機制,並簡化LSTM的內部狀態,更直接的使用前一步的hidden state簡化運算。

雖然圖變得有點難畫~~ 多了跳線跟 1-z的路線

GRU architecture

資料集

AG News,AG是作者的名字,收集超過2000個新聞來源的多分類資料集,在torchtext共收錄四類,全球、運動、商業、科學。

評估

測試準確度達到91%,模型大小為31MB。

confusion matrix

與LSTM在準確度和訓練速度的性能相當。這點與論文中的比較結果相似。

在相同參數下,GRU的權重數量是LSTM的20%,不過因為embedding layer仍佔模型絕大部分權重,所以模型大小相當。

實作

參考

--

--

mz bai

Present data engineer, former data analyst, Kaggle player, loves data modeling.