提交 2547a449 编写于 作者: TensorSense's avatar TensorSense

backward_hook for Grad-CAM

上级 45122761
# coding: utf-8
"""
通过实现Grad-CAM学习module中的forward_hook和backward_hook函数
"""
import cv2
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool1(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def img_transform(img_in, transform):
"""
将img进行预处理,并转换成模型输入所需的形式—— B*C*H*W
:param img_roi: np.array
:return:
"""
img = img_in.copy()
img = Image.fromarray(np.uint8(img))
img = transform(img)
img = img.unsqueeze(0) # C*H*W --> B*C*H*W
return img
def img_preprocess(img_in):
"""
读取图片,转为模型可读的形式
:param img_in: ndarray, [H, W, C]
:return: PIL.image
"""
img = img_in.copy()
img = cv2.resize(img,(32, 32))
img = img[:, :, ::-1] # BGR --> RGB
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.4948052, 0.48568845, 0.44682974], [0.24580306, 0.24236229, 0.2603115])
])
img_input = img_transform(img, transform)
return img_input
def backward_hook(module, grad_in, grad_out):
grad_block.append(grad_out[0].detach())
def farward_hook(module, input, output):
fmap_block.append(output)
def show_cam_on_image(img, mask, out_dir):
heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
heatmap = np.float32(heatmap) / 255
cam = heatmap + np.float32(img)
cam = cam / np.max(cam)
path_cam_img = os.path.join(out_dir, "cam.jpg")
path_raw_img = os.path.join(out_dir, "raw.jpg")
if not os.path.exists(out_dir):
os.makedirs(out_dir)
cv2.imwrite(path_cam_img, np.uint8(255 * cam))
cv2.imwrite(path_raw_img, np.uint8(255 * img))
def comp_class_vec(ouput_vec, index=None):
"""
计算类向量
:param ouput_vec: tensor
:param index: int,指定类别
:return: tensor
"""
if not index:
index = np.argmax(ouput_vec.cpu().data.numpy())
index = index[np.newaxis, np.newaxis]
index = torch.from_numpy(index)
one_hot = torch.zeros(1, 10).scatter_(1, index, 1)
one_hot.requires_grad = True
class_vec = torch.sum(one_hot * output) # one_hot = 11.8605
return class_vec
def gen_cam(feature_map, grads):
"""
依据梯度和特征图,生成cam
:param feature_map: np.array, in [C, H, W]
:param grads: np.array, in [C, H, W]
:return: np.array, [H, W]
"""
cam = np.zeros(feature_map.shape[1:], dtype=np.float32) # cam shape (H, W)
weights = np.mean(grads, axis=(1, 2)) #
for i, w in enumerate(weights):
cam += w * feature_map[i, :, :]
cam = np.maximum(cam, 0)
cam = cv2.resize(cam, (32, 32))
cam -= np.min(cam)
cam /= np.max(cam)
return cam
if __name__ == '__main__':
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
path_img = os.path.join(BASE_DIR, "../../Data/cam_img/", "test_img_1.png")
path_net = os.path.join(BASE_DIR, "../../Data/", "net_params_72p.pkl")
output_dir = os.path.join(BASE_DIR, "../../Result/backward_hook_cam/")
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
fmap_block = list()
grad_block = list()
# 图片读取;网络加载
img = cv2.imread(path_img, 1) # H*W*C
img_input = img_preprocess(img)
net = Net()
net.load_state_dict(torch.load(path_net))
# 注册hook
net.conv2.register_forward_hook(farward_hook)
net.conv2.register_backward_hook(backward_hook)
# forward
output = net(img_input)
idx = np.argmax(output.cpu().data.numpy())
print("predict: {}".format(classes[idx]))
# backward
net.zero_grad()
class_loss = comp_class_vec(output)
class_loss.backward()
# 生成cam
grads_val = grad_block[0].cpu().data.numpy().squeeze()
fmap = fmap_block[0].cpu().data.numpy().squeeze()
cam = gen_cam(fmap, grads_val)
# 保存cam图片
img_show = np.float32(cv2.resize(img, (32, 32))) / 255
show_cam_on_image(img_show, cam, output_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册