GradCamをPyTorchを用いて実装する

CNNの判断の根拠となった部分を可視化する方法としてGradCamが提案されています。
今回はpytorchで提供されている学習済みのVGG16を用いてGradCamの実装を行い、判断根拠の可視化を行います。

# 使用するライブラリの読み込み
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torchvision import models
from tqdm import tqdm_notebook as tqdm
from PIL import Image
import cv2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

学習済みモデル(VGG16)の読み込み

以下でpytorchの学習済みモデルの読み込みを行います。

feature_extractor = models.vgg16(pretrained=True).features
classifier = models.vgg16(pretrained=True).classifier

読み込んだモデルをそれぞれ確認します。

feature_extractor
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (18): ReLU(inplace=True)
  (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (20): ReLU(inplace=True)
  (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (22): ReLU(inplace=True)
  (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (25): ReLU(inplace=True)
  (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (27): ReLU(inplace=True)
  (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (29): ReLU(inplace=True)
  (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
classifier
Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
)

feature_extractorは224x224x3の画像を7x7x512の特徴マップを生成するモデル。
classifierはfeature_extractorにより得られた特徴マップから1000カテゴリに分類するモデル。
また上記の1000カテゴリはImageNetを利用しています

学習済みモデル(VGG16)による予測

可視化を行う前に実際に分類を行ってみます。
(注)使用するデータは正規化をする必要があります。

# 次の平均、分散を用いて正規化
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])
# 画像を256x256にリサイズ → 中央の224を切り取り
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

画像は次の猫の画像を利用します。

img = Image.open('/content/drive/My Drive/Colab Notebooks/cat_img.jpg')
img

# 画像の正規化
img_tensor = preprocess(img)
# 各モデルを予測モードにする
feature_extractor = feature_extractor.eval()
classifier = classifier.eval()
# feature_extractorにより特徴マップを得る
feature = feature_extractor(img_tensor.view(-1,3,224,224))
print('特徴マップのサイズは {}'.format(feature.shape))
特徴マップのサイズは torch.Size([1, 512, 7, 7])
# classifierにより分類を行います
predict = classifier(feature.view(-1,512*7*7))
print('予測されたクラスは {}'.format(torch.argmax(predict,1)))
予測されたクラスは tensor([281])

クラス281はtabby cat(トラ猫)なので、分類が正しく行えました。

GradCamの実装

いよいよ本記事の目的である、GradCamの実装を行います。

# グレースケールをヒートマップにする関数
# 得られたグレースケールの注目箇所をヒートマップに変換する際に用います
def toHeatmap(x):
    x = (x*255).reshape(-1)
    cm = plt.get_cmap('jet')
    x = np.array([cm(int(np.round(xi)))[:3] for xi in x])
    return x.reshape(224,224,3)

以下は論文より引用しています。
\alpha^c_k = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_ij }

A^k_{ij}
: 得られた特徴マップのkチャネル、(i,j)ピクセルを意味します。

y^c
:は予測されたcクラスの出力

上記を用いて計算される \alpha^c_k は特徴マップの各チャネルの重みを計算しています。

上記で得られた各チャネルの重みを利用して次のように注目度を計算します。

L_{Grad-CAM}^c = ReLU(\sum_k \alpha^c_k A^k)
feature = feature_extractor(img_tensor.view(-1,3,224,224)) #特徴マップを計算
feature = feature.clone().detach().requires_grad_(True) #勾配を計算するようにコピー
y_pred = classifier(feature.view(-1,512*7*7)) #予測を行う
y_pred[0][torch.argmax(y_pred)].backward() # 予測でもっとも高い値をとったクラスの勾配を計算
# 以下は上記の式に倣って計算しています
alpha = torch.mean(feature.grad.view(512,7*7),1)
feature = feature.view(512,7,7)
L = F.relu(torch.sum(feature*alpha.view(-1,1,1),0)).cpu().detach().numpy()
# (0,1)になるように正規化
L_min = np.min(L)
L_max = np.max(L - L_min)
L = (L - L_min)/L_max
# 得られた注目度をヒートマップに変換
L = toHeatmap(cv2.resize(L,(224,224)))
# 画像の正規化を戻すのに利用します
mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)

予測に用いた画像は次のようになっています(正規化は戻してあります)

plt.figure(figsize=(10,10))
plt.imshow((img_tensor*std + mean).permute(1,2,0).cpu().detach().numpy())

得られたヒートマップを上記の画像に重ねて表示すると次のようになります。

img1 = (img_tensor*std + mean).permute(1,2,0).cpu().detach().numpy()
img2 = L
alpha = 0.3
blended = img1*alpha + img2*(1-alpha)
# 結果を表示する。
plt.figure(figsize=(10,10))
plt.imshow(blended)
plt.axis('off')
plt.show()

耳と胴体をみて判定していることがわかります。

今回はGradCamを用いてCNNの注目箇所の可視化を行いました。
また次回attention等を用いた可視化技術の実装ができたらいいなと思っております。