AlexNet Heatmap: Grad-CAM Code Implementation
Date: 2024-06-15 11:05:47
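AlexNet is a classic convolutional neural network, and Grad-CAM (Gradient-weighted Class Activation Mapping) is a visualization technique for understanding how a convolutional neural network reaches its decision in an image classification task: it highlights the image regions that contribute most to the score of a chosen class. The steps for implementing Grad-CAM heatmaps for AlexNet are as follows: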
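In the standard Grad-CAM formulation, the channel weights and the heatmap for a class $c$ are computed as shown below. Here $A^k$ is the $k$-th feature map of the chosen layer, $y^c$ is the pre-softmax score of class $c$ (AlexNet's final layer outputs raw logits), and $Z$ is the number of spatial positions $(i, j)$ in the feature map:

$$
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \operatorname{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)
$$

In the class defined in step 3, `F.adaptive_avg_pool2d(self.gradient, 1)` corresponds to $\alpha_k^c$, and the weighted channel sum followed by `F.relu` corresponds to $L^c_{\text{Grad-CAM}}$.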
1. Import the required libraries and modules:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import alexnet
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
```
2. Load the pretrained AlexNet model:
```python
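from torchvision.models import AlexNet_Weights

# `pretrained=True` is deprecated since torchvision 0.13; pass a weights enum instead
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
model.eval()  # disable dropout so the class scores are deterministic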
```
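3. Define the GradCAM class (a forward hook records the target layer's feature maps and a backward hook records their gradients):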
```python
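class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.feature_maps = None  # activations of target_layer (forward pass)
        self.gradient = None      # d(class score) / d(activations) (backward pass)
        self.model.eval()
        self.hook_layers()

    def hook_layers(self):
        # Look up the layer by its name in named_modules(), e.g. 'features'
        layer = dict(self.model.named_modules())[self.target_layer]

        def forward_hook(module, input, output):
            self.feature_maps = output.detach()

        def backward_hook(module, grad_input, grad_output):
            # grad_output[0] is the gradient w.r.t. the layer's output
            self.gradient = grad_output[0].detach()

        # Register both hooks up front so self.gradient is filled during backward().
        # register_full_backward_hook replaces the deprecated register_backward_hook.
        layer.register_forward_hook(forward_hook)
        layer.register_full_backward_hook(backward_hook)

    def generate_heatmap(self, input_image, class_index):
        input_image.requires_grad_()
        model_output = self.model(input_image)  # shape (1, 1000): class scores
        self.model.zero_grad()
        # Backpropagate only the score of the target class
        one_hot_output = torch.zeros_like(model_output)
        one_hot_output[0][class_index] = 1
        model_output.backward(gradient=one_hot_output)
        # alpha_k: global-average-pool the gradients over the spatial dimensions
        weights = F.adaptive_avg_pool2d(self.gradient, 1)                  # (1, K, 1, 1)
        heatmap = (self.feature_maps * weights).sum(dim=1, keepdim=True)  # (1, 1, h, w)
        heatmap = F.relu(heatmap)
        # Upsample the coarse map to the input resolution
        heatmap = F.interpolate(heatmap, size=(input_image.size(2), input_image.size(3)),
                                mode='bilinear', align_corners=False)
        heatmap = heatmap.squeeze()
        # Scale to [0, 1] for display; the epsilon avoids division by zero
        heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
        return heatmap.cpu().numpy()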
```
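The tensor arithmetic inside `generate_heatmap` can be checked in isolation. The following is a minimal NumPy sketch of the same computation for one image, assuming the hooked activations and gradients have already been converted to arrays of shape (K, H, W); the function name `gradcam_from_arrays` is hypothetical, and the final min-max scaling to [0, 1] is an added assumption that only makes the map easier to display.

```python
import numpy as np


def gradcam_from_arrays(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM map from one image's activations and gradients.

    Both arrays have shape (K, H, W): K channels of the hooked layer.
    Returns an (H, W) map scaled to [0, 1] (all zeros if the map is constant).
    """
    if activations.shape != gradients.shape:
        raise ValueError("activations and gradients must have the same shape")
    # alpha_k: global-average-pool the gradient of each channel
    weights = gradients.mean(axis=(1, 2))             # shape (K,)
    # Weighted sum of the feature maps over the channel axis
    cam = np.tensordot(weights, activations, axes=1)  # shape (H, W)
    # Keep only positive evidence for the class
    cam = np.maximum(cam, 0.0)
    # Min-max normalize
    cam = cam - cam.min()
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Made-up numbers: channel 0 has a positive gradient, channel 1 a negative one
acts = np.zeros((2, 2, 2))
acts[0] = [[1.0, 0.0], [0.0, 0.0]]
acts[1] = [[0.0, 0.0], [0.0, 1.0]]
grads = np.stack([np.full((2, 2), 1.0), np.full((2, 2), -1.0)])
cam = gradcam_from_arrays(acts, grads)
print(cam)
```

Only the region activated by the positively weighted channel survives the ReLU. Once both hooks have fired, the real inputs would be `gradcam.feature_maps[0].cpu().numpy()` and `gradcam.gradient[0].cpu().numpy()`.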
4. Load and preprocess the image:
```python
image_path = 'path_to_image.jpg'
image = Image.open(image_path).convert('RGB')
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_image = preprocess(image).unsqueeze(0)
```
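To overlay the heatmap on exactly the pixels the model saw (already resized to 224×224), the `Normalize` step can be undone for display. A minimal NumPy sketch, assuming a CHW float array like the one produced by `transforms.ToTensor()`; `normalize` and `denormalize` are hypothetical helper names, and the mean/std values are the ones passed to `transforms.Normalize` above.

```python
import numpy as np

# ImageNet statistics, the same values passed to transforms.Normalize above
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def normalize(chw: np.ndarray) -> np.ndarray:
    """Per-channel (x - mean) / std on a CHW array, as transforms.Normalize does."""
    return (chw - MEAN) / STD


def denormalize(chw: np.ndarray) -> np.ndarray:
    """Invert normalize() and return an HWC array in [0, 1], ready for plt.imshow."""
    restored = chw * STD + MEAN
    return np.clip(restored, 0.0, 1.0).transpose(1, 2, 0)


# A random CHW "image" in [0, 1), standing in for the output of ToTensor()
rng = np.random.default_rng(0)
pixels = rng.random((3, 224, 224))
display = denormalize(normalize(pixels))
print(display.shape)  # (224, 224, 3)
```

For the tensor from step 4 this would be `denormalize(input_image[0].detach().numpy())`; the `.detach()` is needed because `generate_heatmap` calls `requires_grad_()` on the input.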
5. Create a GradCAM instance and generate the heatmap:
```python
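gradcam = GradCAM(model, target_layer='features')
# Explain the model's own top prediction rather than a fixed index
# (class 0 in ImageNet is "tench"); any index in [0, 999] can be passed instead.
with torch.no_grad():
    class_index = model(input_image).argmax(dim=1).item()
heatmap = gradcam.generate_heatmap(input_image, class_index=class_index)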
```
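With `target_layer='features'`, the hooked activations are the output of the last `MaxPool2d` in `model.features`, so for a 224×224 input the raw Grad-CAM map is only 6×6; this is why `generate_heatmap` upsamples it with `F.interpolate`. The size can be checked with the usual output-size formula for convolution and pooling layers. In this sketch the layer table is transcribed from torchvision's AlexNet definition and `conv_out` is a hypothetical helper name.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of Conv2d/MaxPool2d along one spatial axis (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding) of the layers in torchvision's AlexNet `features`
ALEXNET_FEATURES = [
    ("conv1", 11, 4, 2), ("pool1", 3, 2, 0),
    ("conv2", 5, 1, 2), ("pool2", 3, 2, 0),
    ("conv3", 3, 1, 1), ("conv4", 3, 1, 1), ("conv5", 3, 1, 1),
    ("pool3", 3, 2, 0),
]

size = 224
for name, k, s, p in ALEXNET_FEATURES:
    size = conv_out(size, k, s, p)
    print(f"{name}: {size}x{size}")
# last line printed: pool3: 6x6, the spatial size of the hooked `features` output
```

Each of the 36 cells therefore covers a large patch of the input, which is why AlexNet Grad-CAM maps look blurry after upsampling.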
6. Visualize the heatmap:
```python
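# Resize the PIL image to the 224x224 input size so it lines up with the heatmap
plt.imshow(image.resize((224, 224)))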
plt.imshow(heatmap, alpha=0.5, cmap='jet')
plt.axis('off')
plt.show()
```
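If a single blended image is preferred (for example, to save it to disk instead of stacking two `imshow` calls), the colormap can be applied manually. A minimal sketch using NumPy and matplotlib's colormap registry (`matplotlib.colormaps`, available since matplotlib 3.5), assuming the heatmap is already scaled to [0, 1] and has the same height and width as the image; `overlay_heatmap` is a hypothetical helper name.

```python
import numpy as np
from matplotlib import colormaps


def overlay_heatmap(rgb: np.ndarray, heatmap: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend a [0, 1] heatmap (H, W) onto an RGB image (H, W, 3) with values in [0, 1]."""
    if rgb.shape[:2] != heatmap.shape:
        raise ValueError("image and heatmap must have the same height and width")
    # Map each heatmap value to an RGBA color, then drop the alpha channel
    colored = colormaps["jet"](np.clip(heatmap, 0.0, 1.0))[..., :3]
    # Convex combination keeps the result in [0, 1]
    return (1.0 - alpha) * rgb + alpha * colored


# Synthetic data: a gray image and a left-to-right ramp as the "heatmap"
rgb = np.full((224, 224, 3), 0.5)
heat = np.tile(np.linspace(0.0, 1.0, 224), (224, 1))
blended = overlay_heatmap(rgb, heat)
print(blended.shape)  # (224, 224, 3)
```

With the objects from steps 4 and 5 this might look like `overlay_heatmap(np.asarray(image.resize((224, 224))) / 255.0, heatmap)`, provided `heatmap` has been scaled to [0, 1]. The result can be shown with `plt.imshow` or written to a file with `plt.imsave`.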
These are the steps for generating Grad-CAM heatmaps with AlexNet.