長短期記憶是由Hochreiter et al.(1997, 慕尼黑大學) 在 LONG SHORT-TERM MEMORY 提出,主要用於序列資料的處理,曾經獲得廣泛的使用與研究,所屬的遞歸類神經網路與卷積類神經網路並列深度學習領域的代表,直到注意力機制出現後逐漸淡出。
網路概念
RNN層由一個RNN運算元執行運算,在每個時間步接收新的資訊x_t,將該資訊與前一步內部狀態h_t-1整合並經激活函數輸出為下一步內部狀態h_t,此內部狀態再重複利用於每個對應的時間步,達到對順序特性的建模。
經由以上運算方式,RNN達到以下幾項優勢
- 不同時間步的參數共享,這與CNN的局部區域參數共享相似,使得模型大小較小且GPU用量小
- 過往時間步的資訊可融合到下一時間步,達到自迴歸(Autoregressive)的建模表達
- 利用Backpropagation through time(BPTT)完成參數更新,計算方式簡單
不過優點也是缺點
- 參數共享導致記憶有限,可處理的時間長度難以延長
- 自迴歸特性需要依賴前一次計算的結果,計算複雜度由時間長度決定,導致迴圈等順序性計算的負擔。以前有過CPU比GPU更適合運行的時期
- Backpropagation的深度隨著時間長度的延伸而影響梯度的回推,會有梯度爆炸與梯度消失的問題,常見的範例是多層sigmoid回推使得梯度越來越小的問題。pytorch當下(2024)使用的是tanh和relu。
LSTM則進一步加強RNN的運算,將外部資訊x_t和上一步內部資訊h_t-1的整合結果進行多種運算組合,生成另一個短期內部狀態c_t,也因此LSTM屬於有多個內部狀態傳遞的RNN變形。
如果不看short-term區塊的內部計算,外部IO看來是將RNN的整合結果與c_t-1進行運算後的結果相乘。
short-term區塊中多個不同的gate機制可參考影片或其他文章,參考中的論文也做了數學上的統一,儘管解釋上偏向訊號與系統的架構。
網路架構
先將 token id 轉成嵌入向量,進入LSTM擷取序列特徵,過分類器得到迴歸結果。
資料集
AG News,https://paperswithcode.com/dataset/ag-news,AG新聞的多分類資料集。
評估
混淆矩陣顯示分類結果不錯,達到測試91%準確度。模型大小為37MiB。
筆記
- 模型大小可達數十MB,超過一半以上的參數來源是 embedding layer
- 分類器的輸入是hidden state,而不是output。
- pytorch的LSTM分成cell和layer的包裝,cell僅做概念圖中的算子功能,layer則做較完整的運算,包括雙向和多層的設定,避免多次與GPU的互動。這邊解釋一下pytorch中LSTM的輸出是一個tuple,包含最後一層的全部時間步輸出o_t和所有層的最後一個時間步輸出h_t跟c_t。
結語
本文介紹如何應用簡單的long short-term memory classifier。實作代碼可參考以下連結。
參考