基于pytorch写一个hook提取模型特定层特征图并可视化CAM的代码
时间: 2024-05-08 09:15:11 浏览: 16
以下是基于 PyTorch、利用 hook 提取模型特定层特征图并可视化 CAM 的代码:
```python
import torch
import torch.nn as nn
import cv2
import numpy as np
class CAM():
    """Gradient-weighted class activation maps for one layer of a model.

    A forward hook records the target layer's output and a backward hook
    records the gradient of the class score with respect to that output.
    The map is the ReLU of the sum of the feature maps, each weighted by
    the spatial mean of its gradient, resized to the input resolution and
    scaled to [0, 1].
    """

    def __init__(self, model, target_layer):
        """Remember the model and the layer to explain.

        Args:
            model: The ``torch.nn.Module`` that is run by ``get_cam``.
            target_layer: The sub-module of ``model`` to hook. ``get_cam``
                averages its gradient over axes 2 and 3, so its output is
                expected to be 4-D (batch, channels, height, width).
        """
        self.model = model
        self.target_layer = target_layer
        self.features = []
        self.grads = []
        self.hook_fn = None

    def hook(self, module, input, output):
        """Forward hook: store the layer output as a NumPy array."""
        self.features.append(output.detach().cpu().numpy())

    def hook_backward(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the layer output."""
        self.grads.append(grad_output[0].detach().cpu().numpy())

    def get_cam(self, input_image, class_idx=None):
        """Compute the activation map of ``input_image`` for one class.

        Only the first sample of the batch is turned into a map.

        Args:
            input_image: 4-D tensor (batch, channels, height, width).
            class_idx: Index of the class to explain. Defaults to the
                highest-scoring class of the first sample.

        Returns:
            A float32 NumPy array of shape (height, width) with values in
            [0, 1]; all zeros when the map has no positive evidence.

        Raises:
            RuntimeError: If the hooks recorded nothing, e.g. because the
                target layer did not take part in the forward pass.
        """
        self.features = []
        self.grads = []

        # ``register_backward_hook`` is deprecated; use the full backward hook
        # when this PyTorch version has it. Per the PyTorch docs, the layer's
        # inputs/outputs must then not be modified in place.
        register_backward = getattr(
            self.target_layer, "register_full_backward_hook", None
        )
        if register_backward is None:
            register_backward = self.target_layer.register_backward_hook

        handles = []
        try:
            self.hook_fn = self.target_layer.register_forward_hook(self.hook)
            handles.append(self.hook_fn)
            handles.append(register_backward(self.hook_backward))

            # Run on whatever device the model lives on instead of relying on
            # a global ``device`` name.
            first_param = next(self.model.parameters(), None)
            if first_param is not None:
                input_image = input_image.to(first_param.device)

            # The backward pass needs a graph even if the caller is inside a
            # ``torch.no_grad()`` block.
            with torch.enable_grad():
                output = self.model(input_image)
                if class_idx is None:
                    first_scores = output.detach().reshape(-1, output.size(-1))[0]
                    class_idx = int(first_scores.argmax())
                # Selecting the class column keeps the score on the output's
                # device (a NumPy-built one-hot mask would live on the CPU).
                score = output[..., class_idx].sum()
                self.model.zero_grad()
                score.backward()
        finally:
            # Always detach the hooks, even when the forward/backward fails.
            for handle in handles:
                handle.remove()
            self.hook_fn = None

        if not self.features or not self.grads:
            raise RuntimeError(
                "target_layer hooks recorded no data; is the layer used by the model?"
            )

        grads_val = self.grads[-1]
        target = self.features[-1]
        # One weight per channel: spatial mean of the gradient (first sample).
        weights = np.mean(grads_val, axis=(2, 3))[0, :]
        # Weighted sum over channels -> (height, width) map.
        cam = np.tensordot(weights, target[0], axes=1).astype(np.float32)
        cam = np.maximum(cam, 0)

        # cv2.resize takes the size as (width, height).
        height, width = input_image.shape[2:]
        cam = cv2.resize(cam, (int(width), int(height)))

        cam = cam - np.min(cam)
        peak = np.max(cam)
        if peak > 0:  # avoid 0/0 when the map is all zeros
            cam = cam / peak
        return cam
# 使用示例:
class Net(nn.Module):
    """Small three-stage CNN classifier used as the demo model.

    Each stage is a 3x3 convolution, a ReLU and a 2x2 max-pooling, so the
    spatial size is divided by 8 overall. ``fc1`` takes 128 * 8 * 8
    features, which corresponds to 3-channel 64x64 inputs. The network
    outputs 10 class scores per sample.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Return the class scores, shape (batch, 10), for a batch of images."""
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv3(x))
        x = nn.functional.max_pool2d(x, 2)
        # Flatten per sample. ``view(-1, 128 * 8 * 8)`` would let the batch
        # dimension absorb a wrong spatial size (e.g. four 32x32 images would
        # silently become one sample); this way fc1 raises on a mismatch.
        x = x.view(x.size(0), -1)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Build the demo model and hook its last convolution layer.
model = Net()
target_layer = model.conv3
cam = CAM(model, target_layer)
# Load the image
image_path = 'test.jpg'
image = cv2.imread(image_path)
if image is None:
    # cv2.imread returns None instead of raising when the file is unreadable.
    raise FileNotFoundError(f"could not read image: {image_path}")
# Net.fc1 takes 128 * 8 * 8 features after three 2x2 poolings, so the input
# must be 64x64 (a 32x32 input would only yield 128 * 4 * 4 features).
image = cv2.resize(image, (64, 64))
image = image.astype(np.float32) / 255.
# Keep the height x width x channel image for the overlay; the model gets a
# channel-first copy with a leading batch dimension.
background = image
image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))
image = torch.from_numpy(image)
image = image.unsqueeze(0)
# Generate the CAM
cam_map = cam.get_cam(image)
# Visualize the CAM
heatmap = cv2.applyColorMap(np.uint8(255 * cam_map), cv2.COLORMAP_JET)
heatmap = np.float32(heatmap) / 255
# Both arrays are height x width x 3 here, so they can be added directly.
cam_image = heatmap + background
cam_image = cam_image / np.max(cam_image)
cv2.imshow('CAM', cam_image)
cv2.waitKey(0)
```
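在上面的代码中,我们定义了一个 CAM 类,用来提取模型特定层的特征图,并生成对应的 CAM 图像(这里实际采用的是基于梯度加权的 Grad-CAM 做法)。CAM 类中包含一个 hook 函数,通过前向钩子提取目标层的特征图;以及一个 hook_backward 函数,通过反向钩子提取该特征图对应的梯度。在 get_cam 函数中,我们首先将输入图像经过模型前向传播,然后根据输出结果确定目标类别(未指定 class_idx 时取得分最高的类别)。接着,我们通过反向传播计算目标类别得分对目标层特征图的梯度,将梯度在空间维度上取平均作为各通道的权重,对特征图加权求和,再经过 ReLU、缩放到输入图像尺寸并归一化,得到 CAM 图像。最后,我们将 CAM 图像和原始图像叠加起来,生成可视化的 CAM 图像。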