BERT(Bidirectional Encoder Representations from Transformers)的MLM(Masked Language Model)損失是這樣設(shè)計的:在訓(xùn)練過程中,BERT隨機地將輸入文本中的一些單詞替換為一個特殊的[MASK]標(biāo)記,然后模型的任務(wù)是預(yù)測這些被掩蓋的單詞。具體來說,它會預(yù)測整個詞匯表中每個單詞作為掩蓋位置的概率。
MLM損失的計算方式是使用交叉熵?fù)p失函數(shù)。對于每個被掩蓋的單詞,模型會輸出一個概率分布,表示每個可能的單詞是正確單詞的概率。交叉熵?fù)p失函數(shù)會計算模型輸出的概率分布與真實單詞的分布(實際上是一個one-hot編碼,其中正確單詞的位置是1,其余位置是0)之間的差異。
具體來說,如果你有一個詞匯表大小為V,對于一個被掩蓋的單詞,模型會輸出一個V維的向量,表示詞匯表中每個單詞的概率。如果y是一個one-hot編碼的真實分布,而p是模型預(yù)測的分布,則交叉熵?fù)p失可以表示為(用于衡量模型預(yù)測概率分布與真實標(biāo)簽概率分布之間的差異):
其中:
-
表示損失函數(shù)的值
-
表示類別的數(shù)量
-
是第
個類別的真實標(biāo)簽,通常為0或1
-
是模型預(yù)測第
個類別的概率
-
表示自然對數(shù)
-
表示對所有類別求和
在這個公式中,是真實分布中的第i個元素,而
是模型預(yù)測的分布中的第i個元素。由于y是one-hot編碼的,所以除了正確單詞對應(yīng)的位置為1,其余位置都是0,這意味著上面的求和實際上只在正確單詞的位置計算。
在實際操作中,為了提高效率,通常不會對整個詞匯表進(jìn)行預(yù)測,而是使用采樣技術(shù),如負(fù)采樣(negative sampling)或者層次softmax(hierarchical softmax),來減少每個訓(xùn)練步驟中需要計算的輸出數(shù)量。