LearnFromPapers系列——用“模型想象出來的target”來訓練可以提高分類的效果
<center>作者:郭必揚</center>
<center>時間:2020年最后一天</center>
前言:今天是2020年最后一天,這篇文章也是我的SimpleAI公眾號2020年的最后一篇推文,感謝大家一直以來的陪伴和支持,希望SimpleAI曾帶給各位可愛的讀者們一點點的收獲吧~這么特殊的一天,我也來介紹一篇特殊的論文,那就是今年我和組里幾位老師合作的一篇AAAI論文:“Label Confusion Learning to Enhance Text Classification Models”。這篇文章的主要思想是通過構造一個“標簽混淆模型”來實時地“想象”一個比one-hot更好的標簽分布,從而使得各種深度學習模型(LSTM、CNN、BERT)在分類問題上都能得到更好的效果。個人感覺,還是有、意思的。
- 論文標題:Label Confusion Learning to Enhance Text Classification Models
- 會議/期刊:AAAI-21
- 團隊:上海財經大學 信息管理與工程學院 AI Lab
一、主要貢獻
本文的主要貢獻有這么幾點:
- 構造了一個插件--"Label Confusion Model(LCM)",可以在模型訓練的時候實時計算樣本和標簽間的關系,從而生成一個標簽分布,作為訓練的target,實驗證明,這個新的target比one-hot標簽更好;
- 這個插件不需要任何外部的知識,也僅僅在訓練的時候才需要,不會增加模型預測時的時間,不改變原模型的結構。所以LCM的應用范圍很廣;
- 實驗發現LCM還具有出色的抗噪性和抗干擾能力,對于有錯標的數據集,或者標簽間相似度很高的數據集,有更好的表現。
二、問題背景、相關工作
1. 用one-hot來訓練不夠好
本文主要是從文本分類的角度出發的,但文本分類和圖像分類實際上在訓練模式上是類似的,基本都遵循這樣的一個流程:
step 1. 一個深度網絡(DNN,諸如LSTM、CNN、BERT等)來得到向量表示
step 2. 一個softmax分類器來輸出預測的標簽概率分布p
step 3. 使用Cross-entropy來計算真實標簽(one-hot表示)與p之間的損失,從而優化
這里使用cross-entropy loss(簡稱CE-loss)基本上成了大家訓練模型的默認方法,但它實際上存在一些問題。下面我舉個例子:
比如有一個六個類別的分類任務,CE-loss是如何計算當前某個預測概率p相對于y的損失呢:
可以看出,根據CE-loss的公式,只有y中為1的那一維度參與了loss的計算,其他的都忽略了。這樣就會造成一些后果:
- 真實標簽跟其他標簽之間的關系被忽略了,很多有用的知識無法學到;比如:“鳥”和“飛機”本來也比較像,因此如果模型預測覺得二者更接近,那么應該給予更小的loss
- 傾向于讓模型更加“武斷”,成為一個“非黑即白”的模型,導致泛化性能差;
- 面對易混淆的分類任務、有噪音(誤打標)的數據集時,更容易受影響
總之,這都是由one-hot的不合理表示造成的,因為one-hot只是對真實情況的一種簡化。
2. 一些可能的解決辦法
LDL:
既然one-hot不合理,那我們就使用更合理的標簽分布來訓練嘛。比如下圖所示:
如果我們能獲取真實的標簽分布來訓練,那該多好啊。
這種使用標簽的分布來學習模型的方法,稱為LDL(Label Distribution Learning),東南大學耿新團隊專門研究這個方面,大家可以去了解一下。
但是,真實的標簽分布,往往很難獲取,甚至不可獲取,只能模擬。比如找很多人來投票,或者通過觀察進行統計。比如在耿新他們最初的LDL論文中,提出了很多生物數據集,是通過實驗觀察來得到的標簽分布。然而,大多數的現有的數據集,尤其是文本、圖像分類,幾乎都是one-hot的,所以LDL并無法直接使用。
Label Enhancement:
Label Enhancement,機標簽增強技術,則是一類從通過樣本特征空間來生成標簽分布的方法,我在前面的論文解讀中有介紹,這些方法都很有趣。
然而,使用這些方法來訓練模型,都比較麻煩,因為我們需要通過“兩步走”來訓練,第一步使用LE的方法來構造標簽分布,第二步再使用標簽分布來訓練。
Loss Correction:
面對one-hot可能帶來的容易過擬合的問題,有研究提出了Label Smoothing方法:
label smoothing就是把原來的one-hot表示,在每一維上都添加了一個隨機噪音。這是一種簡單粗暴,但又十分有效的方法,目前已經使用在很多的圖像分類模型中了。
這種方法,一定程度上,可以緩解模型過于武斷的問題,也有一定的抗噪能力。但是單純地添加隨機噪音,也無法反映標簽之間的關系,因此對模型的提升有限,甚至有欠擬合的風險。
當然還有一些其他的Loss Correction方法,可以參考我前面的一個介紹。
三、我們的思想&模型設計
我們最終的目標,是能夠使用更加合理的標簽分布來代替one-hot分布訓練模型,最好這個過程能夠和模型的訓練同步進行。
首先我們思考,一個合理的標簽分布,應該有什么樣的性質。
① 很自然地,標簽分布應該可以反映標簽之間的相似性。
比方下面這個例子:
② 標簽間的相似性是相對的,要根據具體的樣本內容來看。
比方下面這個例子,同樣的標簽,對于不同的句子,標簽之間的相似度也是不一樣的:
③ 構造得到的標簽分布,在01化之后應該跟原one-hot表示相同。
啥意思呢,就是我們不能構造出了一個標簽分布,最大值對應的標簽跟原本的one-hot標簽還不一致,我們最終的標簽分布,還是要以one-hot為標桿來構造。
根據上面的思考,我們這樣來設計模型:
使用一個Label Encoder來學習各個label的表示,與input sample的向量表示計算相似度,從而得到一個反映標簽之間的混淆/相似程度的分布。最后,使用該混淆分布來調整原來的one-hot分布,從而得到一個更好的標簽分布。
設計出來的模型結構如圖:
這個結構分兩部分,左邊是一個Basic Predictor,就是各種我們常用的分類模型。右邊的則是LCM的模型。注意LCM是一個插件,所以左側可以更換成任何深度學習模型。
Basic Predictor的過程可以用如下公式表達:
其中就是輸入的文本的通過Input Decoder得到的表示。
則是predicted label distribution(PLD)。
LCM的過程可以表達為:
其中代表label通過Label Encoder得到的標簽表示矩陣,
是標簽和輸入文本的相似度得到的標簽混淆分布,
是真實的one-hot表示,二者通過一個超參數結合再歸一化,得到最終的
,即模擬標簽分布,simulated label distribution(SLD)。
最后,我們使用KL散度來計算loss:
總體來說還是比較簡單的,很好復現,其實也存在更優的模型結構,我們還在探究。
四、實驗&結果分析
1. Benchmark數據集上的測試
我們使用了2個中文數據集和3個英文數據集,在LSTM、CNN、BERT三種模型架構上進行測試,實驗表明LCM可以在絕大多數情況下,提升主流模型的分類效果。
下面這個圖展示了不同水平的α超參數對模型的影響:
從圖中可以看出,不管α水平如何,LCM加成的模型,都可以顯著提高收斂速度,最終的準確率也更高。針對不同的數據集特征,我們可以使用不同的α(比如數據集混淆程度大,可以使用較小的α),另外,論文中我們還介紹了在使用較小α的時候,可以使用early-stop策略來防止過擬合。
而下面這個圖則展示了LCM確實可以學習到label之間的一些相似性關系,而且是從完全隨機的初始狀態開始學到的:
2. 難以區分的數據集(標簽易混淆)
我們構造了幾個“簡單的”和“困難的”數據集,通過實驗標簽,LCM更適合那些容易混淆的數據集:
3. 有噪音的數據集
我們還測試了在不同噪音水平下的數據集上的效果,并跟Label Smoothing方法做了對比,發現是顯著好于LS方法的。
下面這個圖展示了另外一組更細致的實驗結果:
4. 在圖像分類上也有效果
最后,我們在圖像任務上也簡單測試了一下,發現也有效果:
總結: