本文首發自【簡書】作者【西北小生_】的博客,轉載請私聊作者!
一、什么是CAM?
CAM的全稱是Class Activation Mapping或Class Activation Map,即類激活映射或類激活圖。
論文《Learning Deep Features for Discriminative Localization》發現了CNN分類模型的一個有趣的現象:
CNN的最后一層卷積輸出的特征圖,對其通道進行加權疊加后,其激活值(ReLU激活后的非零值)所在的區域,即為圖像中的物體所在區域。而將這一疊加后的單通道特征圖覆蓋到輸入圖像上,即可高亮圖像中物體所在位置區域。如圖1中的輸入圖像和輸出圖像所示。
該文章作者將實現這一現象的方法命名為類激活映射,并將特征圖疊加在原始輸入圖像上生成的新圖片命名為類激活圖。
二、CAM有什么用?
CAM一般有兩種用途:
- 可視化模型特征圖,以便觀察模型是通過圖像中的哪些區域特征來區分物體類別的;
- 利用卷積神經網絡分類模型進行弱監督的圖像目標定位。
第一種用途是最直接的用途,根據CAM高亮的圖像區域,可以直觀地解釋CNN是如何區分不同類別的物體的。
對于第二種用途,一般的目標定位方法,都需要專門對圖像中的物體位置區域進行標注,并將標注信息作為圖像標簽的一部分,然后通過訓練帶標簽的圖像和專門的目標定位模型才能實現定位,是一種強監督的方法。而CAM方法不需要物體在圖像中的位置信息,僅僅依靠圖像整體的類別標簽訓練分類模型,即可找到圖像中物體所在的大致位置并高亮之,因此可以作為一種弱監督的目標定位方法。
三、CAM原理
如圖2所示,CNN最后一層卷積層輸出的特征圖是三維的:[C, H, W ],設特征圖的第
其中
表示特征圖的每個輸出通道首先被平均為一個值,
個通道得到
個值,然后這些值再被加權相加得到一個數,這個數就是第
類的置信分數,表征著輸入圖像的類別是
的可能性大小。
表示首先對特征圖的每個通道進行加權求和(
),得到一個二維的特征圖(通道維坍塌),然后再對這個二維特征圖求平均值,得到第
類的置信分數。
由公式(1)的推導可知,先對特征圖進行全局平均池化,再進行加權求和得到類別的置信分數,等價于先對特征圖進行通道維度的加權求和,再進行全局平均池化。
經過這一等價變換,就突顯了特征圖通道加權和的重要性了:一方面,特征圖的通道加權和直接編碼了類別信息;另一方面,也是最重要的,特征圖的通道加權和是二維的,還保留著圖像的空間位置信息。我們可以通過可視化方法觀察到圖像中的相對位置信息與CNN編碼的類別信息的關系。
這里的特征圖的通道加權之和就叫做類別激活圖。
四、CAM的PyTorch實現
本文以PyTorch自帶的ResNet-18為例,分步驟講解并用代碼實現CAM的整個流程和細節。
1.準備工作
首先導入需要用到的包:
import math
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from typing import Optional, List
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models
from torch import Tensor
from matplotlib import cm
from torchvision.transforms.functional import to_pil_image
定義輸入圖片路徑,和保存輸出的類激活圖的路徑:
img_path = '/home/dell/img/1.JPEG' # 輸入圖片的路徑
save_path = '/home/dell/cam/CAM1.png' # 類激活圖保存路徑
定義輸入圖片預處理方式。由于本文用的輸入圖片來自ILSVRC-2012驗證集,因此采用PyTorch官方文檔提供的ImageNet驗證集處理流程:
preprocess = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
2.獲取CNN最后一層卷積層的輸出特征圖
本文選用的CNN模型是PyTorch自帶的ResNet-18,首先導入預訓練模型:
net = models.resnet18(pretrained=True).cuda() # 導入模型
由于特征圖是模型前向傳播時的中間變量,不能直接從模型中獲取,需要用到PyTorch提供的hook工具,補課請參考我的這兩篇博客:hook1,hook2。
通過輸出模型(print(net))我們就能看到ResNet-18輸出最后一層特征圖的層為net.layer4(或者net.layer4[1]、net.layer4[1].bn2都可)。我們用hook工具注冊這一層,以便獲得它的輸出特征圖:
feature_map = [] # 建立列表容器,用于盛放輸出特征圖
def forward_hook(module, inp, outp): # 定義hook
feature_map.append(outp) # 把輸出裝入字典feature_map
net.layer4.register_forward_hook(forward_hook) # 對net.layer4這一層注冊前向傳播
做好了hook的定義和注冊工作,現在只需要對輸入圖片進行預處理,然后執行一次模型前向傳播即可獲得CNN最后一層卷積層的輸出特征圖:
orign_img = Image.open(img_path).convert('RGB') # 打開圖片并轉換為RGB模型
img = preprocess(orign_img) # 圖片預處理
img = torch.unsqueeze(img, 0) # 增加batch維度 [1, 3, 224, 224]
with torch.no_grad():
out = net(img.cuda()) # 前向傳播
這時我們想要的特征圖已經裝在列表feature_map中了。我們輸出尺寸來驗證一下:
In [10]: print(feature_map[0].size())
torch.Size([1, 512, 7, 7])
3.獲取權重
CAM使用的權重是全連接輸出層中,對應這張圖像所屬類別的權重。文字表述可能存在歧義或不清楚,直接看本文最上面的圖中全連接層被著色的連接??梢钥吹?,每個連接對應一個權重值,左邊和特征圖的每個通道(全局平均池化后)一一連接,右邊全都連接著輸出類別所對應的那個神經元。
由于我也不知道這張圖的類別標簽,這里假設模型對這張圖像分類正確,我們來獲得其輸出類別所對應的權重:
cls = torch.argmax(out).item() # 獲取預測類別編碼
weights = net._modules.get('fc').weight.data[cls,:] # 獲取類別對應的權重
4.對特征圖的通道進行加權疊加,獲得CAM
cam = (weights.view(*weights.shape, 1, 1) * feature_map[0].squeeze(0)).sum(0)
這里的代碼比較簡單,擴充權重的維度([512, ][512, 1, 1])是為了使之在通道上與特征圖相乘;去除特征圖的batch維([1, 512, 7, 7]
[512, 7, 7])是為了使其維度和weights擴充后的維度相同以相乘。最后在第一維(通道維)上相加求和,得到一個
的類激活圖。
5.對CAM進行ReLU激活和歸一化
這一步有兩個細節需要注意:
- 上步得到的類激活圖像素值分布雜亂,要想確定目標位置,須先進行ReLU激活,將正值保留,負值置零。像素值正值所在的(一個或多個)區域即為目標定位區域。
- 上步獲得的激活圖還只是一個普通矩陣,需要變換成圖像規格,將其值歸一化到[0,1]之間。
我們首先定義歸一化函數:
def _normalize(cams: Tensor) -> Tensor:
"""CAM normalization"""
cams.sub_(cams.flatten(start_dim=-2).min(-1).values.unsqueeze(-1).unsqueeze(-1))
cams.div_(cams.flatten(start_dim=-2).max(-1).values.unsqueeze(-1).unsqueeze(-1))
return cams
然后對類激活圖執行ReLU激活和歸一化,并利用PyTorch的 to_pil_image函數將其轉換為PIL格式以便下步處理:
cam = _normalize(F.relu(cam, inplace=True)).cpu()
mask = to_pil_image(cam.detach().numpy(), mode='F')
將類激活圖轉換成PIL格式是為了方便下一步和輸入圖像融合,因為本例中我們選用的PIL庫將輸入圖像打開,選用PIL庫也是因為PyTorch處理圖像時默認的圖像格式是PIL格式的。
6.將類激活圖覆蓋到輸入圖像上,實現目標定位
這一步也有很多細節需要注意:
- 上步得到的類激活圖只有
的尺寸,想要將其覆蓋在輸入圖像上顯示,就需將其用插值的方法擴大到和輸入圖像相同大小。
- 我們的目的是用類激活圖中被激活(非零值)的位置區域,來高亮原始圖像中相應的位置區域,這一高亮的方法就是將激活圖變換為熱力圖的形式:值越大的像素顏色越紅,值越小的像素顏色越藍。
- 如果直接將熱力圖覆蓋到原始輸入圖像上,會遮蔽圖像中的內容導致不容易觀察,因此需要設置兩個圖像融合的比例(透明度),即在兩種圖像融合在一起時,將原始輸入圖像的像素值權重設置大一些,而把熱力圖的像素值權重設置小一些,這樣就會使生成圖像中原始輸入圖像的內容更加清晰,易于觀察。(mixup方法同理)
- 兩種圖像融合后的像素值會超出圖像規格像素值的范圍[0,1],因此還需要將其轉換為圖像規格。
我們將兩個圖像交疊融合的過程封裝成了函數:
def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = 'jet', alpha: float = 0.6) -> Image.Image:
"""Overlay a colormapped mask on a background image
Args:
img: background image
mask: mask to be overlayed in grayscale
colormap: colormap to be applied on the mask
alpha: transparency of the background image
Returns:
overlayed image
"""
if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image):
raise TypeError('img and mask arguments need to be PIL.Image')
if not isinstance(alpha, float) or alpha < 0 or alpha >= 1:
raise ValueError('alpha argument is expected to be of type float between 0 and 1')
cmap = cm.get_cmap(colormap)
# Resize mask and apply colormap
overlay = mask.resize(img.size, resample=Image.BICUBIC)
overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 1:]).astype(np.uint8)
# Overlay the image with the mask
overlayed_img = Image.fromarray((alpha * np.asarray(img) + (1 - alpha) * overlay).astype(np.uint8))
return overlayed_img
接下來就是激動人心的時刻了?。?!將類激活圖作為掩碼,以一定的比例覆蓋到原始輸入圖像上,生成類激活圖:
result = overlay_mask(orign_img, mask)
這里的變量result已經是有著PIL圖片格式的類激活圖了,我們可以通過:
result.show()
可視化輸出,也可以通過:
result.save(save_path)
將圖片保存在本地查看。我們在這里展示一下輸入圖像和輸出定位圖像的對比: