1 生成對抗網絡概述
有時候我們希望網絡具有一定的創造力,比如畫畫、編曲等等,能否實現呢?是可以實現的,大家可以鑒別一下下面這幾張照片,哪些是真實的人臉,哪些是機器生成的人臉。很難判斷吧?本節最后會給出答案。
要實現上述能力,就要用到一種新的網絡架構— 生成對抗網絡(Generative Adversarial Net,GAN
) 。首先,我們大概來了解一下什么是 “生成” ,什么是 “對抗”。
8.1.1 對“生成”的理解
假設我們設計一個網絡,將其稱為 “生成器(Generator)”。生成器的輸入是一個向量,該向量一般是低維向量,它是通過一個特定的分布采樣出來的,例如正態分布。生成器的輸出是另一個向量
,該向量是一個高維向量,比如一個二次元的人臉。由于生成器的輸入向量是通過一個分布隨機采樣的,所以輸入向量每次都是不一樣的,因此生成器每次的輸出也是不一樣的,會形成一個復雜的分布。盡管輸出向量不一樣,但是我們要求這些輸出向量都是二次元的人臉,而不是其它。也就是說期望生成器輸出的復雜分布要和某個特定分布(例如所有二次元人臉的集合)盡可能相似,如何做到呢?這就要用到“對抗”。
1.2 對“對抗”的理解
我們常說要“感謝對手”,為什么呢?因為對手逼得我們不斷想辦法進步,最后讓我們進化成長為優秀的人。為了使生成網絡不斷進化以成為畫畫高手,我們還需要訓練另外一個網絡,叫做 “鑒別器(Discriminator)” 。鑒別器是專門用來和生成網絡進行對抗的,就是用它來逼得生成網絡不斷進化。鑒別器的輸入是一張圖片,它的輸出則是一個0-1的數字,數字越大就越認為這張圖片是一個二次元圖片,數字越小呢就越認為這張圖片不是一個二次元圖片。比如下圖中上面兩張圖片很清楚是二次元,所以鑒別器輸出1.0,而下面兩張圖片很模糊,所以鑒別器輸出0.1。因此,簡單點講,鑒別器的功能就是判斷某張圖片到底是不是二次元圖片。
現在我們把這個鑒別器拿過來和生成器進行對抗:
- ①版本1的生成器的參數是隨機生成的,所以其生成的圖片啥都不是。這 時候,我們對鑒別器進行訓練,以使鑒別器能夠鑒別出哪些是生成器生成的圖片,哪些是真實的二次元人臉。經過訓練后,我們得到了版本1的鑒別器。
- ②在版本1的鑒別器的基礎上,我們再來訓練生成器,訓練的目的是讓鑒別器分辨不出哪些是生成器生成的圖片,哪些是真實的二次元人臉。通過訓練之后,得到了版本2的生成器,此時生成的圖片有一點點像二次元了,足以騙過版本1的鑒別器。
- ③在版本2的生成器的基礎上,我們接著訓練鑒別器,同樣是要使鑒別器能夠鑒別出哪些是版本2生成器生成的圖片,哪些是真實的二次元人臉。通過訓練之后,得到了版本2的鑒別器。
- ④重復上述過程,不斷進化生成器和鑒別器,最后生成器可以生成非常逼真的二次元人臉。
通過上述過程我們可以看出,生成器和鑒別器在不斷的對抗過程中,兩者都在不斷的進步,可以說是對抗成就了對方。所以,它們亦敵亦友,相愛相殺,既對立又統一。
2 生成對抗網絡的理論基礎
我們剛才提到生成器的輸入是由一個簡單的分布(如正態分布)采樣得到的一堆向量,輸出是一堆向量構成另一個一個復雜的分布,用表示。我們期望
和某個特定的分布盡可能地相似,而這個分布來自于一堆真實的數據,這個分布表示為
。如果我們用
來表示這兩個分布的Divergence(這個英文不好翻譯,暫且理解為“差異程度”吧),那么我們的目標就是尋找一個生成器
要使
最小,即,
我們知道在機器學習中,訓練的目標是要使損失函數最小,所以在該任務中損失函數就是
。但是有一個很關鍵的問題,我們如何計算這兩個分布的Divergence呢?好像沒法用解析式去描述這兩個分布的Divergence,那怎么辦呢?我們可以通過采樣的方式來計算這兩個分布的Divergence。
采樣是很好辦的,以二次元人臉生成器為例。
假設從
最重要的一點是,
我們可以從直觀上來理解為什么
既然我們已經知道
對了,本節最前面的人臉全部是由機器生成的,驚嘆吧!?