Pytorch實作系列— LSTM

mz bai
Dec 3, 2020

--

長短期記憶是由Hochreiter et al.(1997, 慕尼黑大學) 在 LONG SHORT-TERM MEMORY 提出,主要用於序列資料的處理,曾經獲得廣泛的使用與研究,所屬的遞歸類神經網路與卷積類神經網路並列深度學習領域的代表,直到注意力機制出現後逐漸淡出。

網路概念

RNN概念圖

RNN層由一個RNN運算元執行運算,在每個時間步接收新的資訊x_t,將該資訊與前一步內部狀態h_t-1整合並經激活函數輸出為下一步內部狀態h_t,此內部狀態再重複利用於每個對應的時間步,達到對順序特性的建模。

經由以上運算方式,RNN達到以下幾項優勢

  1. 不同時間步的參數共享,這與CNN的局部區域參數共享相似,使得模型大小較小且GPU用量小
  2. 過往時間步的資訊可融合到下一時間步,達到自迴歸(Autoregressive)的建模表達
  3. 利用Backpropagation through time(BPTT)完成參數更新,計算方式簡單

不過優點也是缺點

  1. 參數共享導致記憶有限,可處理的時間長度難以延長
  2. 自迴歸特性需要依賴前一次計算的結果,計算複雜度由時間長度決定,導致迴圈等順序性計算的負擔。以前有過CPU比GPU更適合運行的時期
  3. Backpropagation的深度隨著時間長度的延伸而影響梯度的回推,會有梯度爆炸與梯度消失的問題,常見的範例是多層sigmoid回推使得梯度越來越小的問題。pytorch當下(2024)使用的是tanh和relu。
LSTM概念圖

LSTM則進一步加強RNN的運算,將外部資訊x_t和上一步內部資訊h_t-1的整合結果進行多種運算組合,生成另一個短期內部狀態c_t,也因此LSTM屬於有多個內部狀態傳遞的RNN變形。

如果不看short-term區塊的內部計算,外部IO看來是將RNN的整合結果與c_t-1進行運算後的結果相乘。

short-term區塊中多個不同的gate機制可參考影片或其他文章,參考中的論文也做了數學上的統一,儘管解釋上偏向訊號與系統的架構。

動畫參考

網路架構

先將 token id 轉成嵌入向量,進入LSTM擷取序列特徵,過分類器得到迴歸結果。

資料集

AG News,https://paperswithcode.com/dataset/ag-news,AG新聞的多分類資料集。

評估

混淆矩陣顯示分類結果不錯,達到測試91%準確度。模型大小為37MiB。

筆記

  1. 模型大小可達數十MB,超過一半以上的參數來源是 embedding layer
  2. 分類器的輸入是hidden state,而不是output。
  3. pytorch的LSTM分成cell和layer的包裝,cell僅做概念圖中的算子功能,layer則做較完整的運算,包括雙向和多層的設定,避免多次與GPU的互動。這邊解釋一下pytorch中LSTM的輸出是一個tuple,包含最後一層的全部時間步輸出o_t和所有層的最後一個時間步輸出h_t跟c_t。
RNN輸出解釋圖 in PyTorch

結語

本文介紹如何應用簡單的long short-term memory classifier。實作代碼可參考以下連結。

參考

--

--

mz bai

Present data engineer, former data analyst, Kaggle player, loves data modeling.