CAM系列(一)之CAM(原理講解和PyTorch代碼實現)

本文首發自【簡書】作者【西北小生_】的博客,轉載請私聊作者!


圖1 CAM實現示意圖

一、什么是CAM?

CAM的全稱是Class Activation MappingClass Activation Map,即類激活映射類激活圖。

論文《Learning Deep Features for Discriminative Localization》發現了CNN分類模型的一個有趣的現象:
CNN的最后一層卷積輸出的特征圖,對其通道進行加權疊加后,其激活值(ReLU激活后的非零值)所在的區域,即為圖像中的物體所在區域。而將這一疊加后的單通道特征圖覆蓋到輸入圖像上,即可高亮圖像中物體所在位置區域。如圖1中的輸入圖像和輸出圖像所示。

該文章作者將實現這一現象的方法命名為類激活映射,并將特征圖疊加在原始輸入圖像上生成的新圖片命名為類激活圖。

二、CAM有什么用?

CAM一般有兩種用途:

  • 可視化模型特征圖,以便觀察模型是通過圖像中的哪些區域特征來區分物體類別的;
  • 利用卷積神經網絡分類模型進行弱監督的圖像目標定位。

第一種用途是最直接的用途,根據CAM高亮的圖像區域,可以直觀地解釋CNN是如何區分不同類別的物體的。

對于第二種用途,一般的目標定位方法,都需要專門對圖像中的物體位置區域進行標注,并將標注信息作為圖像標簽的一部分,然后通過訓練帶標簽的圖像和專門的目標定位模型才能實現定位,是一種強監督的方法。而CAM方法不需要物體在圖像中的位置信息,僅僅依靠圖像整體的類別標簽訓練分類模型,即可找到圖像中物體所在的大致位置并高亮之,因此可以作為一種弱監督的目標定位方法。

三、CAM原理

圖2 輸出結構示意圖

如圖2所示,CNN最后一層卷積層輸出的特征圖是三維的:[C, H, W ],設特征圖的第k個通道可表示為f_k(x,y),其中x,y分別是寬和高維度上的索引。若最后一個卷積層連接一個全局平均池化層,然后再由一個全連接層輸出分類結果,則由最后一個卷積層的輸出特征圖到輸出層中的第c個類別的置信分數(未進行Softmax映射前)的計算過程可表示為:
S_c=\sum_{k}w_{k}^{c} \sum_{x,y}f_{k}(x,y)=\sum_{x,y} \sum_{k} w_{k}^{c} f_{k}(x,y) \tag{1}
其中\sum_{x,y}f_{k}(x,y)為全局平均池化(省略了除以元素總數),由于只對空間上到寬和高兩個維度求和,結果就是這兩個維度坍塌,只剩通道維度保持不變,即計算結果為C個數值,每個值代表著該通道上所有值的平均值。w_{k}^{c}表示全連接輸出層中第c類對應的C個權重中的第k個:即全連接層的權重矩陣W[N_o,C]維的(N_o即輸出類別數,C是最后一層卷積層的輸出通道數),那么第c類對應的權重w^c就應該是W[c,:]w^c有著C個權重參數,對應著每個輸入值(即全局平均池化的結果),w^c_k就是這C個權重參數中的第k個數。

\sum_{k}w_{k}^{c} \sum_{x,y}f_{k}(x,y)表示特征圖的每個輸出通道首先被平均為一個值,C個通道得到C個值,然后這些值再被加權相加得到一個數,這個數就是第c類的置信分數,表征著輸入圖像的類別是c的可能性大小。

\sum_{x,y} \sum_{k} w_{k}^{c} f_{k}(x,y)表示首先對特征圖的每個通道進行加權求和(\sum_{k} w_{k}^{c} f_{k}(x,y)),得到一個二維的特征圖(通道維坍塌),然后再對這個二維特征圖求平均值,得到第c類的置信分數。

由公式(1)的推導可知,先對特征圖進行全局平均池化,再進行加權求和得到類別的置信分數,等價于先對特征圖進行通道維度的加權求和,再進行全局平均池化。

經過這一等價變換,就突顯了特征圖通道加權和\sum_{k} w_{k}^{c} f_{k}(x,y)的重要性了:一方面,特征圖的通道加權和直接編碼了類別信息;另一方面,也是最重要的,特征圖的通道加權和是二維的,還保留著圖像的空間位置信息。我們可以通過可視化方法觀察到圖像中的相對位置信息與CNN編碼的類別信息的關系。

這里的特征圖的通道加權之和\sum_{k} w_{k}^{c} f_{k}(x,y)就叫做類別激活圖。

四、CAM的PyTorch實現

本文以PyTorch自帶的ResNet-18為例,分步驟講解并用代碼實現CAM的整個流程和細節。

1.準備工作

首先導入需要用到的包:

import math
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from typing import Optional, List
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models
from torch import Tensor
from matplotlib import cm
from torchvision.transforms.functional import to_pil_image

定義輸入圖片路徑,和保存輸出的類激活圖的路徑:

img_path = '/home/dell/img/1.JPEG'     # 輸入圖片的路徑
save_path = '/home/dell/cam/CAM1.png'    # 類激活圖保存路徑

定義輸入圖片預處理方式。由于本文用的輸入圖片來自ILSVRC-2012驗證集,因此采用PyTorch官方文檔提供的ImageNet驗證集處理流程:

preprocess = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
2.獲取CNN最后一層卷積層的輸出特征圖

本文選用的CNN模型是PyTorch自帶的ResNet-18,首先導入預訓練模型:

net = models.resnet18(pretrained=True).cuda()   # 導入模型

由于特征圖是模型前向傳播時的中間變量,不能直接從模型中獲取,需要用到PyTorch提供的hook工具,補課請參考我的這兩篇博客:hook1hook2。

通過輸出模型(print(net))我們就能看到ResNet-18輸出最后一層特征圖的層為net.layer4(或者net.layer4[1]、net.layer4[1].bn2都可)。我們用hook工具注冊這一層,以便獲得它的輸出特征圖:

feature_map = []     # 建立列表容器,用于盛放輸出特征圖

def forward_hook(module, inp, outp):     # 定義hook
    feature_map.append(outp)    # 把輸出裝入字典feature_map

net.layer4.register_forward_hook(forward_hook)    # 對net.layer4這一層注冊前向傳播

做好了hook的定義和注冊工作,現在只需要對輸入圖片進行預處理,然后執行一次模型前向傳播即可獲得CNN最后一層卷積層的輸出特征圖:

orign_img = Image.open(img_path).convert('RGB')    # 打開圖片并轉換為RGB模型
img = preprocess(orign_img)     # 圖片預處理
img = torch.unsqueeze(img, 0)     # 增加batch維度 [1, 3, 224, 224]

with torch.no_grad():
    out = net(img.cuda())     # 前向傳播

這時我們想要的特征圖已經裝在列表feature_map中了。我們輸出尺寸來驗證一下:

In [10]: print(feature_map[0].size())
torch.Size([1, 512, 7, 7])
3.獲取權重

CAM使用的權重是全連接輸出層中,對應這張圖像所屬類別的權重。文字表述可能存在歧義或不清楚,直接看本文最上面的圖中全連接層被著色的連接??梢钥吹?,每個連接對應一個權重值,左邊和特征圖的每個通道(全局平均池化后)一一連接,右邊全都連接著輸出類別所對應的那個神經元。

由于我也不知道這張圖的類別標簽,這里假設模型對這張圖像分類正確,我們來獲得其輸出類別所對應的權重:

cls = torch.argmax(out).item()    # 獲取預測類別編碼
weights = net._modules.get('fc').weight.data[cls,:]    # 獲取類別對應的權重
4.對特征圖的通道進行加權疊加,獲得CAM
cam = (weights.view(*weights.shape, 1, 1) * feature_map[0].squeeze(0)).sum(0)

這里的代碼比較簡單,擴充權重的維度([512, ]\rightarrow[512, 1, 1])是為了使之在通道上與特征圖相乘;去除特征圖的batch維([1, 512, 7, 7]\rightarrow[512, 7, 7])是為了使其維度和weights擴充后的維度相同以相乘。最后在第一維(通道維)上相加求和,得到一個7\times 7的類激活圖。

5.對CAM進行ReLU激活和歸一化

這一步有兩個細節需要注意:

  • 上步得到的類激活圖像素值分布雜亂,要想確定目標位置,須先進行ReLU激活,將正值保留,負值置零。像素值正值所在的(一個或多個)區域即為目標定位區域。
  • 上步獲得的激活圖還只是一個普通矩陣,需要變換成圖像規格,將其值歸一化到[0,1]之間。

我們首先定義歸一化函數:

def _normalize(cams: Tensor) -> Tensor:
        """CAM normalization"""
        cams.sub_(cams.flatten(start_dim=-2).min(-1).values.unsqueeze(-1).unsqueeze(-1))
        cams.div_(cams.flatten(start_dim=-2).max(-1).values.unsqueeze(-1).unsqueeze(-1))

        return cams

然后對類激活圖執行ReLU激活和歸一化,并利用PyTorch的 to_pil_image函數將其轉換為PIL格式以便下步處理:

cam = _normalize(F.relu(cam, inplace=True)).cpu()
mask = to_pil_image(cam.detach().numpy(), mode='F')

將類激活圖轉換成PIL格式是為了方便下一步和輸入圖像融合,因為本例中我們選用的PIL庫將輸入圖像打開,選用PIL庫也是因為PyTorch處理圖像時默認的圖像格式是PIL格式的。

6.將類激活圖覆蓋到輸入圖像上,實現目標定位

這一步也有很多細節需要注意:

  • 上步得到的類激活圖只有7\times 7的尺寸,想要將其覆蓋在輸入圖像上顯示,就需將其用插值的方法擴大到和輸入圖像相同大小。
  • 我們的目的是用類激活圖中被激活(非零值)的位置區域,來高亮原始圖像中相應的位置區域,這一高亮的方法就是將激活圖變換為熱力圖的形式:值越大的像素顏色越紅,值越小的像素顏色越藍。
  • 如果直接將熱力圖覆蓋到原始輸入圖像上,會遮蔽圖像中的內容導致不容易觀察,因此需要設置兩個圖像融合的比例(透明度),即在兩種圖像融合在一起時,將原始輸入圖像的像素值權重設置大一些,而把熱力圖的像素值權重設置小一些,這樣就會使生成圖像中原始輸入圖像的內容更加清晰,易于觀察。(mixup方法同理)
  • 兩種圖像融合后的像素值會超出圖像規格像素值的范圍[0,1],因此還需要將其轉換為圖像規格。

我們將兩個圖像交疊融合的過程封裝成了函數:

def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = 'jet', alpha: float = 0.6) -> Image.Image:
    """Overlay a colormapped mask on a background image

    Args:
        img: background image
        mask: mask to be overlayed in grayscale
        colormap: colormap to be applied on the mask
        alpha: transparency of the background image

    Returns:
        overlayed image
    """

    if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image):
        raise TypeError('img and mask arguments need to be PIL.Image')

    if not isinstance(alpha, float) or alpha < 0 or alpha >= 1:
        raise ValueError('alpha argument is expected to be of type float between 0 and 1')

    cmap = cm.get_cmap(colormap)    
    # Resize mask and apply colormap
    overlay = mask.resize(img.size, resample=Image.BICUBIC)
    overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 1:]).astype(np.uint8)
    # Overlay the image with the mask
    overlayed_img = Image.fromarray((alpha * np.asarray(img) + (1 - alpha) * overlay).astype(np.uint8))

    return overlayed_img

接下來就是激動人心的時刻了?。?!將類激活圖作為掩碼,以一定的比例覆蓋到原始輸入圖像上,生成類激活圖:

result = overlay_mask(orign_img, mask) 

這里的變量result已經是有著PIL圖片格式的類激活圖了,我們可以通過:

result.show()

可視化輸出,也可以通過:

result.save(save_path)

將圖片保存在本地查看。我們在這里展示一下輸入圖像和輸出定位圖像的對比:


(左)輸入圖像;(右)定位圖像
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 227,702評論 6 531
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 98,143評論 3 415
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 175,553評論 0 373
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 62,620評論 1 307
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 71,416評論 6 405
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 54,940評論 1 321
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,024評論 3 440
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 42,170評論 0 287
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 48,709評論 1 333
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 40,597評論 3 354
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 42,784評論 1 369
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,291評論 5 357
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,029評論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,407評論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 35,663評論 1 280
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 51,403評論 3 390
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 47,746評論 2 370

推薦閱讀更多精彩內容