生成基于自己创建的CNN回归模型的Grad-CAM可视化类激活的热力图 PyTorch实现代码
时间: 2024-01-24 18:16:52 浏览: 276
以下是基于自己创建的CNN回归模型的Grad-CAM可视化类激活的热力图的PyTorch实现代码:
```python
import torch
import torch.nn.functional as F
import cv2
import numpy as np
class GradCAM:
def __init__(self, model, target_layer):
self.model = model
self.target_layer = target_layer
self.feature_maps = None
self.gradient = None
self.activation_maps = None
def forward(self, x):
self.feature_maps = []
self.gradient = []
for name, module in self.model.named_modules():
x = module(x)
if name == self.target_layer:
x.register_hook(self.gradient_hook)
self.feature_maps.append(x)
elif "conv" in name:
self.feature_maps.append(x)
self.activation_maps = x
return self.activation_maps
def gradient_hook(self, grad):
self.gradient.append(grad)
def backward(self):
gradient = self.gradient[0]
feature_maps = self.feature_maps[-1]
batch_size, channel, height, width = feature_maps.shape
weights = F.adaptive_avg_pool2d(gradient, (1, 1))
weights = weights.view(batch_size, channel)
activation_maps = feature_maps.view(batch_size, channel, -1)
weights = weights.unsqueeze(-1)
cam = torch.bmm(activation_maps, weights)
cam = F.relu(cam)
cam = cam.view(batch_size, 1, height, width)
cam = F.interpolate(cam, size=(224, 224), mode='bilinear', align_corners=False)
cam = cam.squeeze()
return cam.detach().cpu().numpy()
model = YourCNNModel()
gradcam = GradCAM(model, 'conv2') # target_layer is the layer you want to visualize
# input image
img = cv2.imread('image.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224))
img = np.transpose(img, (2, 0, 1))
img = np.expand_dims(img, axis=0)
img = torch.from_numpy(img).float()
# forward pass and backward pass
output = model(img)
cam = gradcam.forward(img)
cam = gradcam.backward()
# plot the heat map
plt.imshow(img.squeeze().permute(1, 2, 0))
plt.imshow(cam, alpha=0.5, cmap='jet')
plt.show()
```
其中,`GradCAM`类实现了Grad-CAM算法的前向传播和反向传播,`model`是你自己创建的CNN回归模型,`target_layer`是你想要可视化的卷积层名称。在使用时,需要将输入图像转换为PyTorch张量,然后调用`gradcam.forward()`得到卷积层的特征图和梯度,再调用`gradcam.backward()`得到类激活热力图。最后,将输入图像和类激活热力图叠加在一起即可得到可视化结果。
阅读全文