前邊介紹了TensorFlow的基本操作和Keras的高層接口:
tf2.0學習(一)——基礎知識
tf2.0學習(二)——進z階知識
tf2.0學習(三)——神經網絡
tf2.0學習(四)——反向傳播算法
tf2.0學習(五)——Keras高層接口
下面我們接好一下在訓練過程中經常要面對的一個問題,過擬合,以及在TensorFlow這個框架中如何更好的處理這個問題。
6.0 簡介
機器學習的主要目的,是通過訓練集學習到數據的真實模型,從而在未見過的測試集上能有良好的表現,這種能力叫做模型的泛化能力。通常來說,訓練集和測試集都采樣自某個相同的數據分布p(x)。采樣到的樣本是相互獨立的,但又來自同一個分布,我們把這種假設叫做獨立同分布假設(簡稱:i.i.d)。
模型的表達能力,也叫做模型的容量。當模型的表達能力偏弱時,會導致無法充分學習到數據的特征,從而導致模型性能很差,這時候模型在訓練集和測試集上的表現都很差。當模型的表達能力過強時,又會導致模型學習過于充分,甚至學到了訓練集中的噪聲,這時候模型在訓練集上表現很好,但在測試集上的表現很差。
6.1 模型的容量
通俗的講,模型的容量或表達能力,就是模型擬合復雜函數的能力。一種體現模型容量的指標叫做模型的假設空間,即模型可以表示的函數集的大小。假設空間越大越完備,就越有可能從假設空間中搜索到能夠擬合真實數據的函數,相反,如果假設空間很小,就很難找到擬合真實數據的函數。
假設一數據集采樣自如下分布:
該數據集引入了一些觀測誤差,如下圖小圓點所示。如果只搜索1次多項式的模型空間,那么最多能擬合出一條直線來,效果很差。如果搜索空間增加到3次多項式函數,此時假設空間明顯大于1次多項式的情況,此時能擬合出一條曲線,效果能有些提升。如果繼續增加多項式的冪次,那么假設空間越來越大,搜索的范圍也越來越大,就約有可能找到擬合效果更好的模型。
但是過大的搜索空間,無疑會增加模型的搜索難度和計算代價。實際上在有限的計算資源下,較大的搜索空間并不一定能找出更好的函數模型。相反,隨著假設空間中可能存在表達能力過強的模型,學習到了訓練集中的噪聲數據,從而傷害了模型的泛化能力。因此在實際情況中,往往根據具體任務,選擇合適的假設空間的模型。
6.2 過擬合與欠擬合
由于真實數據的分布往往是未知又復雜的,而且無法推斷出其分布函數的類型和參數,因此人們在學習模型時,往往根據根據經驗選擇較大的模型容量。
但模型容量過大時,搜索到的模型,可能由于表達能力過強,不僅學到了數據本省的模態,還學到了數據中的觀測誤差,這就會導致模型在訓練集上的表現很好,但在未見的新樣本上表現不佳,泛化能力弱,這種現象叫做模型的過擬合。當模型容量過小時,模型可能不能很好的學習到數據的模態,就會導致模型在訓練集上表現不佳,在未見過的新樣本上表現也很差,這種現象叫做欠擬合。
那么如何選擇合適的模型容量呢?統計學習理論給我們提供了一些思路,VC維是機器學習領域,一個比較通用的度量模型容量的方法。盡管這些方法給機器學習提供了一些理論保證,但在深度學習領域卻很難應用,一部分原因是神經網絡的機構復雜,很難確定網絡背后的數學模型的VC維度。
但是,我們可以根據奧卡姆剃刀原則,指導神經網絡的設計和訓練。“切勿浪費較多東西,去做‘用較少的東西,同樣可以做好的事 情’”。也就是說,如果兩層的神經網絡結構能夠很好的表達真實模型,那么三層的神經 網絡也能夠很好的表達,但是我們應該優先選擇使用更簡單的兩層神經網絡,因為它的參 數量更少,更容易訓練,也更容易通過較少的訓練樣本獲得不錯的泛化誤差。
6.2.1 欠擬合
欠擬合的原因,往往是模型容量不足,導致在假設空間內找不到一個合適的函數很好的擬合數據。表現是在訓練集上誤差很好,在測試集上的表現也很差。遇到這種情況,我們一般考慮增加模型的復雜度,增加數據維度等辦法處理。但由于以深度學習為代表的很多模型,可以輕易達到很深的維度,模型復雜度往往很高,所以欠擬合的問題一般不如過擬合的問題常見。
6.2.2 過擬合
現在說一下過擬合。當模型容量很大,可供搜索的假設空間也就會很大,這時候模型的表達能力過于強大,很可能會學習到訓練數據中的觀測誤差,導致在訓練集上的表現很好,但在測試集上的表現卻很差。這時候往往就是過擬合了。本章接下來的內容更多用于介紹如何避免過擬合。
6.3 數據集劃分
我們在做機器學習任務過程中,數據集要劃分為訓練集和測試集,但為了選擇模型超參數和檢測過擬合現象,往往再將訓練集劃分為訓練集和驗證集。也就是一個數據集會被劃分成訓練集、驗證集、測試集三部分。
6.3.1 驗證集與超參數
前邊已經介紹了訓練集和測試集,訓練集主要用來訓練模型,測試集主要用來驗證模型的泛化能力。測試集的樣本不能出現在訓練集中,防止模型學到測試集的信息,導致測試集不能真正反應模型的泛化能力,是一種有損模型泛化的行為。訓練集和測試集一般都采樣自同一分布的數據,對應比例可以根據情況調節。
但只將數據分為訓練集和測試集是不夠的,由于測試集不參與到模型訓練中,所以測試集不能用來作為模型訓練的實時反饋,而模型訓練過程中,我們需要挑選合適的參數模型,需要有個數據集對模型性能進行實時反饋,判斷模式是否過擬合。因此一般再將訓練集劃分為訓練集和驗證集。劃分后的訓練集主要用來訓練模型,驗證集主要用來進行超參數的選擇。
驗證集和測試集的主要區別在于,開發人員可以根據驗證集的反饋結果進行模型參數的調整,而訓練集一般只是用來驗證模型整體泛化能力。
6.3.2 提前停止(early stopping)
一般來說,把訓練集的一個Batch運算更新一次叫做一個step,對訓練集的所有樣本循環迭代一次叫做一個Epoch,整個訓練過程可能會進行多個Epoch。驗證集一般可以在間隔數次Step或數次Epoch之后,對模型進行驗證。如果驗證過于頻繁,雖然能清楚的記錄模型性能,但會帶來額外的計算消耗,所以一般建議間隔幾個Epoch進行一次驗證。
在訓練過程中,我們會同時關心訓練集和驗證集的誤差、準確率等指標。如果模型的訓練誤差 較低,訓練準確率較高,但是驗證誤差較高,驗證準確率較低,那么可能出現了過擬合現 象。如果訓練集和驗證集上面的誤差都較高,準確率較低,那么可能出現了欠擬合現象。
當發現模型過擬合時,可以通過重新設計模型容量,如減少網絡層數,添加正則化項等方式。
實際上,由于模型是隨著訓練不斷變化的,因此同一個模型可能會出現不同的過擬合、欠擬合。可以看到在訓練的前期,訓練集和測試集準確率都在不斷提升,沒有出現過擬合現象。但隨著訓練的持續,在某個Epoch出,會出現過擬合現象,具體表現如下圖所示,訓練集準確率不斷升高,而測試集準確率卻在不斷下降。
那么可不可以,在模型訓練到合適的Epoch時,就停止訓練,從而只過擬合現象的發生呢。我們可以通過觀察驗證集的準確率,找到合適的Epoch,當驗證集在連續幾個Epoch都沒有準確率的提升時,我們可以認為已經到了最合適的Epoch附近,從而提前挺尸訓練,避免訓練過度,發生過擬合。
6.4 模型設計
通過驗證集可以判斷網絡模型是否過擬合或欠擬合,從而為調整網絡模型的容量提供依據。對于神經網絡來說,網絡的層數和參數量時衡量網絡容量的重要參考指標。當網絡過擬合時,可以適當減少網絡層數或減少網絡層的參數量,從而降低網絡容量。反之如果發現模型欠擬合,擇可以加大模型容量。
6.5 正則化
通過不同層數和參數的網絡模型,可為優化算法提供初始的函數假設空間,但函數的假設空間時隨著訓練而不斷變化的。我們以多項式模型為例。
上述模型的容量,可以簡單的用n來衡量,但如果我們限制了 都為0的話,那么該模型的容量就變為k。因此可以通過限制網絡參數的稀疏性,限制網絡容量。
這種約束一般是在損失函數上添加額外的懲罰項。添加懲罰項之前的優化目標是:
添加懲罰項之后的優化目標是:
其中,懲罰項約束一般通過參數的范數來量化:
叫做參數
的l范數。
常用的正則化項有L0、L1、L2正則化。
6.5.1 L0正則化
對于L0正則化項,定義是所有中,非零元素的個數。但由于L0范數不可導,因此在神經網絡中使用的并不多。
6.5.2 L1正則化
L1 正則化也叫 Lasso Regularization。它是連續可導的,在神經網絡中使用廣泛。
w1 = tf.random.uniform([4, 3])
w2 = tf.random.uniform([4, 3])
loss_reg = tf.reduce_sum(tf.math.abs(w1) + tf.math.abs(w2))
6.5.3 L2正則化
L2 正則化也叫 Ridge Regularization,它和 L1 正則化一樣,也是連續可導的,在神經網絡中使用廣泛。
w1 = tf.random.uniform([4, 3])
w2 = tf.random.uniform([4, 3])
loss_reg = tf.reduce_sum(tf.square(w1) + tf.square(w2))
6.6 Dropout
Dropout是一種在神經網絡里經常用到的防止過擬合的方法。在訓練階段隨機斷開一部分神經網絡的連接,減少每次訓練時實際參與計算的參數量(如下圖右所示);但在測試階段,會恢復所有鏈接。
在TensorFlow中,可以通過增加dropout操作和添加dropout層來實現dropout。
# 增加dropout操作
x = tf.nn.dropout(x, rage=0.5)
# 增加dropout層
model.add(tf.keras.layers.Dropout(0.5))
6.7 數據增強
還有一種簡單直接的防止過擬合的方式,就是增加訓練數據。但實際上收集數據成本高昂,我們可以在已有的數據集上,通過數據增強,獲取更多的訓練數據。數據增強(Data Augmentation)是指在維持樣本標簽不變的前提下,根據先驗知識,改變樣本的特征,使得新生成的樣本也符合或近似符合數據的真實分布。
6.7.1 圖像
在圖像領域,可以通過對圖片進行旋轉、翻轉、裁剪等方式,將一張圖片衍生出多張圖片。
6.7.2 生成數據
如GAN網絡等,將在后邊章節介紹。
6.7.3 其他
可以對數據增加少量噪聲,同義詞替換,多次翻譯等方式。