Seq2seq 由Sutskever et al.(2014, Google) 於 Sequence to Sequence Learning with Neural Networks 提出,結合編碼與解碼的概念,使LSTM可以處理更廣泛的自然語言任務,是現代NLP的經典架構。
資料集
Multi30K,是一個英德語翻譯的資料集。
總共29000個訓練語句對、1000個驗證語句對。
網路架構
Seq2Seq解決的問題是兩個不同長度序列間的對應,與CTC loss解決的點類似。
模型架構由encoder先將來源序列的資訊編碼成LSTM的內部狀態向量,再透過decoder取用此內部狀態初始化,循序生成新的句子。兩個模型可傳遞一組共用的內部訊息,一個負責轉換來源序列,一個負責解開來源序列與生成目標序列,達到自編碼的效果。
隨後Bahdanau et al (2015) 和 Luong et al (2015) 發表不同類型的注意力機制與Seq2Seq的自編碼器結構結合的方式,注意力機制的角色是將注意力放在來源序列中較為重要的位置,達到對齊不同時間步預測與來源序列中較為重要的部分。
其中Luong的方式分為Global attention和Local attention,前者使用整個來源序列計算注意力權重,後者則有估計的來源序列位置,只對估計位置周圍取樣作為高斯核函數與注意力權重做加權調整。這點與Soft NMS相似。
解碼器推測下一個預測時通常使用該時間步預測中機率最大的token作為預測結果,只考慮當下最佳解的行為符合貪婪演算法的思想,被稱為貪婪搜尋。
而另一種方法是採用beam search(束集搜尋),每個時間步預測時都保留top k的預測,進入下一步時則以利用此k個預測計算條件利率,找出下一個時間步的top k個預測。以此類推,可得到k個束集結果,再從其中挑出最好的結果,屬於BFS(寬度優先搜尋)的一種,好處是避免像貪婪演算法只收束在top 1的結果,缺點是會多出k-1倍的計算量。
訓練
採用 Cross Entropy loss,並使用teacher forcing作為強迫提示。
評估
訓練10個epoch,在驗證集的表現尚可,會出現重複token且語句不通順。
結語
- 嘗試論文中的技巧:逆向來源序列可減短與目標序列的依賴距離,但結果與正向無明顯差異,僅第一個epoch較快收斂。
- Local attention占用較少GPU但計算步驟多較耗時,Global attention較快但占用較多GPU。Pytorch的運算優化主要來自算子本身,一旦Python的算子變得瑣碎,跟GPU的溝通量變多就會浪費很多資源。
- 可以再修改的部分:以beam search做解碼,目前貪婪搜尋在驗證集的表現不差。