73FAST.AI 深度學習實踐課程--FAST.AI 圖像分類實踐

FAST.AI 圖像分類實踐

計算機視覺是深度學習中最常見的應用領域,其中主要有:圖像分類、圖像生成,對象檢測、目標跟蹤、語義分割、實例分割等。FAST.AI 作為一款基于 PyTorch 開發的快速深度學習工具,自然也就包含有大量更便利的圖像處理模塊和方法。接下來,我們將以最常見的圖像分類為例,使用 FAST.AI 進行實踐。

圖像分類

圖像分類是最為常見的一項深度學習任務,一般情況下,完成該類任務會有 3 個重要步驟。

image.png

首先,我們需要對原始數據進行處理,將圖像數據轉換為深度學習工具能夠支持的張量數據。這一步驟往往就是制作相應的數據加載器。當然,FAST.AI 也有自己對應的數據加載器 DataBunch 對象,這部分內容已在前面章節完成學習。
接下來,就是構建深度神經網絡模型。圖像處理相關的任務,大部分都會使用卷積神經網絡模型。卷積神經網絡是一種非常擅長解決計算機視覺任務的神經網絡模型。當然,無論是 PyTorch,還是 TensorFlow,構建一個神經網絡模型的難度不高,我們往往只需要調用相應深度神經網絡框架完成層堆疊即可。
最后,就是神經網絡訓練的部分。這部分代碼一般是最為復雜的,我們需要對數據進行適當地處理,以正確的方式輸入到神經網絡。最后,對神經網絡的輸出進行處理和評估。神經網絡訓練的部分需要有一定的構建經驗才能完成,尤其是在 PyTorch 的應用過程中,相對于 TensorFlow 更為復雜。
FAST.AI 基于 PyTorch 開發,實際上我認為其最大的改進之處就是優化了 PyTorch 訓練神經網絡復雜的過程。接下來,我們將通過一個圖像分類示例,來學習使用 FAST.AI 完成一個完整的圖像分類任務。
數據處理
接下來,我們選擇前面接觸過的 MNIST 數據集進行演示,MNIST 是一個 10 個類別的圖像分類任務,數據體積較小,非常適合作為工具使用方法的示例數據。首先,我們加載數據,并構建 DataBunch 對象,這部分內容實際上已經學習過了。

from fastai.datasets import untar_data, URLs, download_data
from fastai.vision import ImageDataBunch

# 因原數據集下載較慢,從藍橋云課服務器下載數據,本次實驗時無需此行代碼
download_data("https://labfile.oss.aliyuncs.com/courses/1445/mnist_png")
mnist_path = untar_data(URLs.MNIST)
mnist_data = ImageDataBunch.from_folder(mnist_path, 'training', 'testing')
mnist_data

模型構建
構建完數據加載器 DataBunch 之后,接下來就可以開始構建模型了。一般情況下,構建一個圖像分類模型有 2 種思路,分別是從頭構建和遷移學習。從頭構建,即意味著由你自己設計模型的結構和參數。而遷移學習則是利用一些在經典神經網絡結構上預訓練的模型進行學習。

首先,我們選擇從頭構建模型。FAST.AI 提供了一個非常友好的接口 fastai.vision.simple_cnn 來快速實現卷積神經網絡的構建。該 API 包含 4 個參數:
actns:定義卷積模塊的數量和輸入輸出大小。
kernel_szs:定義卷積核大小,默認為 3。
strides:定義卷積步長大小,默認為 2。
bn:是否包含批量歸一化操作,布爾類型。

接下來,我們就調用該接口來快速定義一個卷積神經網絡。

from fastai.vision import simple_cnn

model = simple_cnn(actns=(3, 16, 16, 10))
model

如上所示,我們只是定義了卷積神經網絡包含的卷積模塊數量和輸入輸出大小。該參數主要注意輸入和輸出尺寸,其中,(3, 16, 16, 10) 表示有 4 個卷積模塊。因為前面的 DataBunch 對象尺寸為 (3, 28, 28),即為 3 個通道圖像,所以第一層卷積操作的輸入尺寸為 3。由于是 10 分類問題,所以最后一個數字是 10。中間層的尺寸可以自定義,我們選擇了 16。fastai.vision.simple_cnn 最終會自動構建為 PyTorch 支持的 Sequential 順序模型。
訓練評估
有了模型之后,我們就可以開始第三步,也就是訓練過程。FAST.AI 的模型訓練過程會用到其核心類 fastai.vision.Learner。最簡單的情況下,我們只需要將數據 DataBunch,模型和評估指標傳入,即可開始訓練。

from fastai.vision import Learner, accuracy

# 傳入數據,模型和準確度評估指標
learner = Learner(mnist_data, model, metrics=[accuracy])
learner

如上所示,我們定義的 Learner 選擇了 accuracy 準確度作為評估指標。你可以通過 Learner 的輸出看到其他相關的默認參數設置。例如優化器 opt_func 選擇了 Adam,損失函數 loss_func 選擇了交叉熵。
加下來,我們可以調用 Learner 完成最終的訓練,訓練方法為 Learner.fit,傳入迭代次數 Epoch 即可。

learner.fit(1)  # 數據集上訓練迭代 1 次

最終,Learner 會打印出最終的訓練損失,驗證損失,準確度和訓練所用時長。至此,我們就使用 FAST.AI 完成了一次針對 MNIST 的圖像分類過程。我們整理上面的完整代碼如下:

mnist_path = untar_data(URLs.MNIST)
mnist_data = ImageDataBunch.from_folder(mnist_path, 'training', 'testing')
model = simple_cnn(actns=(3, 16, 16, 10))

learner = Learner(mnist_data, model, metrics=[accuracy])
learner.fit(1)

你可以看出,使用 FAST.AI 完成 MNSIT 分類我們只使用了 5 行代碼,而相比于 PyTorch 需要的數十行代碼和復雜的構建過程,FAST.AI 中的 FAST 是顯而易見的。

遷移學習

上面,我們從頭構建了一個卷積神經網絡并針對 MNSIT 進行了訓練。由于 MNSIT 數據集本身就質量較高,背景純凈,數據規范,所以最終準確度還是不錯的。如果你將上方 Learner 訓練迭代次數調至 3~5 次,準確度還會有一定的提升,并最終超過 90%。但對于一些復雜的任務,尤其是樣本數據不規范的情況下,從頭開始訓練并不是一個很明智的選擇。所以,很多時候我們會使用預訓練模型做遷移學習。
遷移學習是一種站在巨人的肩膀上的訓練方法。我們可以沿用一些經典神經網絡在大型數據集訓練好的模型,使用自定義數據集繼續更新其中部分層的權重。最終,可以在較少的時間下取得不錯的訓練效果。
FAST.AI 提供的預訓練模型大部分直接來自于 PyTorch,你可以通過 此頁面 瀏覽這些模型。接下來,我們以 ResNet18 為例,針對上面的 MNIST 數據完成一次遷移學習過程。ResNet18 是 ResNet 精簡結構在 ImageNet 數據集上得到的預訓練模型,首先載入該模型并查看結構。

from fastai.vision import models

models.resnet18()

可以看出,相比于我們之前自行搭建的 CNN 結構,ResNet18 要復雜很多。解析來的訓練過程需要利用 fastai.vision.cnn_learner 類來構建 Learner,這一點也與上面有所不同。你只需要記住,如果是從頭開始就使用 fastai.vision.Learner,如果是遷移學習就使用 fastai.vision.cnn_learner 即可。

from fastai.vision import cnn_learner

# 構建基于 ResNet18 的 Learner 學習器
learner = cnn_learner(mnist_data, models.resnet18, metrics=[accuracy])
learner.fit(1)  # 訓練迭代 1 次

你可以看到 Learner 會自動下載 ResNet18 的 .pth 預訓練權重文件,然后開始訓練迭代過程。訓練過程相對于上方會更長一些,原因是模型復雜度更高。最終,使用 ResNet18 完成 1 次迭代的準確度,應該會比上方我們自定義的模型高一些。
所以,當我們使用 FAST.AI 執行遷移學習時,代碼可以進一步精簡至 4 行。這對于使用 PyTorch 和 TensorFlow 是不可想象的簡單,也體現了高階 API 的優勢。

mnist_path = untar_data(URLs.MNIST)
mnist_data = ImageDataBunch.from_folder(mnist_path, 'training', 'testing')

learner = Learner(mnist_data, models.resnet18, metrics=[accuracy])
learner.fit(1)

雖然我們的準確度已經達到了 90% 以上,但模型仍然對部分驗證數據無法準確區分。接下來,我們可以通過 FAST.AI 提供的 fastai.vision.ClassificationInterpretation 方法來對結果進行進一步分析。

from fastai.vision import ClassificationInterpretation

# 載入學習器
interp = ClassificationInterpretation.from_learner(learner)
interp

首先,我們可以輸出那些被分類器預測錯誤的樣本進行觀察。直接通過 interp.plot_top_losses 方法輸出損失最大的 9 個驗證樣本,并比對它們本來的標簽和預測結果。

interp.plot_top_losses(9, figsize=(9, 9))

上面依次輸出了預測標簽,真實標簽,損失和預測概率。你可以看到,部分樣本的確人眼都很難完成辨識,當然也有一些人眼可辨識樣本被錯誤分類。
除了比對圖像,FAST.AI 還提供了一個非常方便的方法 interp.plot_confusion_matrix。通過該方法,我們可以直接繪制出真實標簽和預測標簽之間的混淆矩陣。

interp.plot_confusion_matrix(figsize=(5, 5), dpi=100)

混淆矩陣展示了真實標簽和預測標簽對應樣本的數量。可以看出,0-9 這 10 類樣本在分布上沒有明顯的傾斜。你也可以進一步看出,到底哪些樣本更容易被預測錯誤,以及被錯誤預測的標簽結果。

數據擴增

數據在神經網絡訓練過程中伴有很大的左右,如果符合要求的數據越多,往往訓練的結果也更好。所以,很多時候我們會對現有數據進行一些旋轉、變換、鏡像、歸一化等操作。這些操作不僅可以在一定程度上起到數據擴增的效果,能夠對模型訓練帶來一些幫助。

FAST.AI 提供了一個非常方便的函數 fastai.vision.get_transforms 用于對圖像進行變換,該函數的主要參數有:
do_flip:如果為 True,則以 0.5 的概率應用隨機翻轉。
flip_vert:應用水平翻轉。如果 do_flip=True 時,則可以垂直翻轉圖像或旋轉 90 度。
max_rotate:如果不為 None,則在 -max_rotatemax_rotate 度之間隨機旋轉,概率為 p_affine
max_zoom:如果不是 1 或小于 1,則在 1 之前進行隨機縮放,并以 p_affine 概率應用 max_zoom
max_lighting:如果不為 None,則以 max_lighting 概率 p_lighting 施加由 max_lighting 控制的隨機噪聲和對比度變化。
max_warp:如果不是 None,則以概率 p_affine 施加 -max_warpmaw_warp 之間的隨機對稱扭曲。
p_affine:應用每個仿射變換和對稱扭曲的概率。
p_lighting:應用每個照明變換的概率。
xtra_tfms:您想要應用的其他變換的列表。

接下來,通過一個直觀的例子來演示數據變換擴增的效果。我們讀取 MNIST 訓練數據中第一個樣本:

img, label = mnist_data.train_ds[0]
img.show(title=f'{label}')

然后,我們嘗試對該數據進行隨機旋轉變換操作。為了更加方便地演示旋轉后的效果,這里定義一個輔助繪圖函數 plots_f

from fastai.vision import get_transforms
from matplotlib import pyplot as plt
%matplotlib inline

# 輔助繪圖函數,參考自 FAST.AI 官方文檔
def plots_f(rows, cols, width, height, **kwargs):
    [img.apply_tfms(tfms[0], **kwargs).show(ax=ax) for i, ax in enumerate(plt.subplots(
        rows, cols, figsize=(width, height))[1].flatten())]

接下來,定義變換操作并應用繪圖。

# 定義變換操作,最大 [-25, 25] 度之間的隨機旋轉
tfms = get_transforms(max_rotate=25)
# 繪制樣本變換后圖像
plots_f(2, 4, 12, 6, size=224)

可以看到,樣本被執行了 -25 度到 25 度之間的隨機旋轉操作。不過,上面的示例有一定的缺陷。因為對于手寫字符,較大幅度的旋轉或鏡像圖像會嚴重影響樣本所反映的內容,甚至變成完全不是數字的樣子。所以,對于 MNIST 這類數據,我們往往只能應用小幅度的旋轉、添加噪聲等變換,以避免對樣本本身含義的影響。但是,對于如下所示的動物圖像,更大幅度的變換對數據集擴增更有意義。

image.png

fastai.vision.get_transforms 操作一般會直接添加至 DataBunch 對象生成過程中,這樣就可以將變換操作應用于樣本數據。

# 示例,制作 DataBunch 對象時添加 get_transforms 操作
tfms = get_transforms(max_rotate=25)
tfms_data = ImageDataBunch.from_folder(mnist_path, 'training', 'testing', ds_tfms=tfms)
tfms_data.show_batch(rows=3, figsize=(5,5))

CIFAR10 圖像分類挑戰

前面的挑戰中,我們已經熟悉了 CIFAR10 數據集,并將其處理成 FAST.AI 支持的 DataBunch 對象。本次挑戰中,我們同樣需讀取 CIFAR10 數據集,并添加針對數據集變換的預處理過程。
接下來,請將 CIFAR10 數據集處理成 DataBunch 對象。挑戰要求,將 train 文件夾中數據分離 20% 作為驗證集,剩下數據作為訓練集。test 文件夾下數據作為測試集。同時,加入 get_transforms 變換,應用[?30,30] 度之間的隨機旋轉變換。

from fastai.datasets import untar_data, URLs, download_data
from fastai.vision import ImageDataBunch, get_transforms

download_data("http://labfile.oss.aliyuncs.com/courses/1445/cifar10")
data_path = untar_data(URLs.CIFAR)
# 針對數據集變換
tfms = get_transforms(max_rotate=30)
data_bunch = ImageDataBunch.from_folder(data_path, train='train', test='test',
                                        valid_pct=0.2, ds_tfms=tfms)

接下來,請使用 FAST.AI 提供的建模方法,應用卷積神經網絡對 CIFAR10 完成分類和評估。你可以自由選擇「從頭開始訓練」或「遷移學習方法」。遷移學習所使用的預訓練模型也可以通過閱讀官方文檔自由選擇。
挑戰最終要求,驗證集上的分類準確度不得低于 70%。由于訓練時間較長,你可以在恰當的時候中止訓練。

from fastai.vision import models, cnn_learner, accuracy
models.resnet18()
# 構建基于 ResNet18 的 Learner 學習器
learner = cnn_learner(data_bunch, models.resnet18, metrics=[accuracy])
learner.fit(15)  # 訓練迭代 15 次

僅供參考,accuracy 最終大于 70% 即可。

?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 227,797評論 6 531
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 98,179評論 3 414
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 175,628評論 0 373
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 62,642評論 1 309
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 71,444評論 6 405
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 54,948評論 1 321
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,040評論 3 440
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 42,185評論 0 287
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 48,717評論 1 333
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 40,602評論 3 354
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 42,794評論 1 369
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,316評論 5 358
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,045評論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,418評論 0 26
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 35,671評論 1 281
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 51,414評論 3 390
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 47,750評論 2 370

推薦閱讀更多精彩內容