生成基于 VGG16 模型的 Grad-CAM 的 PyTorch 实现代码
时间: 2024-01-22 12:21:32 浏览: 126
pytorch获取vgg16-feature层输出的例子
5星 · 资源好评率100%
以下是基于VGG16模型的Grad-CAM PyTorch实现代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from PIL import Image
class VGG16(nn.Module):
    """VGG16 image classifier built from torchvision's pretrained network.

    The model is split into ``features`` (convolutional stack), ``avgpool``
    and ``classifier`` so that callers can attach hooks to ``features``.
    """

    def __init__(self):
        super(VGG16, self).__init__()
        backbone = models.vgg16(pretrained=True)
        self.features = backbone.features
        # The first classifier layer expects 512 * 7 * 7 inputs, so the
        # pooled map must be 7x7. Pooling to 1x1 would give only 512 values
        # per sample and make the linear layer fail with a shape mismatch.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Reuse the pretrained head (Linear-ReLU-Dropout-Linear-ReLU-Dropout-
        # Linear, 1000 outputs). A freshly initialised head would produce
        # random class scores and therefore meaningless predictions.
        self.classifier = backbone.classifier

    def forward(self, x):
        """Return the class scores for a batch of images.

        Args:
            x: Image batch accepted by ``self.features``
               (presumably ``(N, 3, H, W)`` float tensor -- confirm at caller).

        Returns:
            Tensor of shape ``(N, 1000)`` with un-normalised class scores.
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
class GradCAM:
    """Grad-CAM visualiser for a model that exposes a ``features`` sub-module.

    A forward hook on ``model.features`` records the activation map, and a
    tensor hook on that activation records its gradient during the backward
    pass. ``generate`` combines both into a colour heat map blended with the
    input image.

    Attributes:
        model: The wrapped model, switched to eval mode.
        feature_maps: List holding the most recent ``features`` output.
        gradient_maps: List holding the gradient of the selected class score
            with respect to that output (filled by ``backward``).
    """

    def __init__(self, model):
        self.model = model.eval()
        self.feature_maps = []
        self.gradient_maps = []
        # Class scores of the latest forward pass; consumed by backward().
        self._logits = None
        # Only a forward hook is registered on the module. The gradient is
        # captured with a tensor hook installed inside save_feature_maps,
        # which avoids the deprecated module-level register_backward_hook.
        self.model.features.register_forward_hook(self.save_feature_maps)

    def save_feature_maps(self, module, input, output):
        """Forward hook: remember the activation and watch its gradient."""
        # Keep only the latest activation so repeated forward passes do not
        # accumulate tensors (and their autograd graphs) without bound.
        self.feature_maps[:] = [output]
        if output.requires_grad:
            output.register_hook(
                lambda grad: self.save_gradient_maps(module, None, (grad,))
            )

    def save_gradient_maps(self, module, grad_input, grad_output):
        """Store ``grad_output[0]``, the gradient w.r.t. the feature maps."""
        self.gradient_maps.append(grad_output[0].detach())

    def forward(self, x):
        """Run the model on ``x`` and return its class scores."""
        # Gradients are required later, even if the caller is inside a
        # torch.no_grad() context.
        with torch.enable_grad():
            self._logits = self.model(x)
        return self._logits

    def backward(self, idx):
        """Back-propagate the score of class ``idx`` (first sample only).

        Raises:
            RuntimeError: If ``forward`` has not been called since the last
                ``backward``, or the scores are detached from the graph.
        """
        if self._logits is None:
            raise RuntimeError("call forward() before backward()")
        if not self._logits.requires_grad:
            raise RuntimeError(
                "model output does not require grad; cannot compute Grad-CAM"
            )
        self.model.zero_grad()
        self.gradient_maps.clear()
        one_hot = torch.zeros_like(self._logits)
        one_hot[0][idx] = 1
        # The graph is freed by backward(), so drop the reference first to
        # make a second backward() without forward() fail with a clear error.
        logits, self._logits = self._logits, None
        logits.backward(gradient=one_hot)

    @staticmethod
    def _jet(values):
        """Map values in [0, 1] to RGB with a piecewise-linear jet palette."""
        scaled = 4.0 * values
        channels = [
            torch.clamp(1.5 - torch.abs(scaled - centre), 0.0, 1.0)
            for centre in (3.0, 2.0, 1.0)  # red, green, blue ramp centres
        ]
        return torch.stack(channels, dim=-1)

    def generate(self, x, idx):
        """Return the Grad-CAM overlay for class ``idx`` on the first image.

        Args:
            x: Input batch of shape ``(N, C, H, W)``; only ``x[0]`` is drawn.
            idx: Index of the class whose evidence should be highlighted.

        Returns:
            ``numpy.ndarray`` of dtype ``uint8`` and shape ``(H, W, 3)`` in
            RGB order: 50% input image, 50% heat map.

        Raises:
            RuntimeError: If no gradient reached ``model.features``.
        """
        self.forward(x)
        self.backward(idx)
        if not self.gradient_maps:
            raise RuntimeError(
                "no gradient reached model.features; "
                "its output must require grad"
            )

        grads = self.gradient_maps[-1][:1].float()
        acts = self.feature_maps[-1].detach()[:1].float()

        # Channel weights are the spatially averaged gradients. They are
        # broadcast over the channel axis without touching the stored
        # activations.
        weights = grads.mean(dim=(2, 3), keepdim=True)
        cam = F.relu((weights * acts).sum(dim=1, keepdim=True))

        # Normalise to [0, 1]; an all-zero map stays zero instead of
        # producing NaNs.
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        cam = F.interpolate(
            cam, size=tuple(x.shape[2:]), mode="bilinear", align_corners=False
        )
        heat = self._jet(cam[0, 0].clamp(0.0, 1.0))

        # The input is typically mean/std normalised, so rescale it to
        # [0, 1] before blending. Otherwise the image is invisible next to
        # the heat map and negative values wrap around when cast to uint8.
        image = x[0].detach().float().permute(1, 2, 0)
        lo, hi = image.min(), image.max()
        if hi > lo:
            image = (image - lo) / (hi - lo)
        else:
            image = torch.zeros_like(image)

        blended = 0.5 * image + 0.5 * heat
        blended = (blended * 255.0).round().clamp(0, 255)
        return blended.to(torch.uint8).cpu().numpy()
# Preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalise each channel with the fixed mean/std values below.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Build the network and wrap it in the Grad-CAM helper.
model = VGG16()
gradcam = GradCAM(model)

# Read the image from disk as RGB and turn it into a batch of one.
img = Image.open('input.jpg').convert('RGB')
input_tensor = transform(img).unsqueeze(0)

# The class with the highest score is the one that gets visualised.
output = gradcam.forward(input_tensor)
predicted_idx = output.argmax().item()

# Compute the overlay and write it to disk.
cam = gradcam.generate(input_tensor, predicted_idx)
output_img = Image.fromarray(cam)
output_img.save('output.jpg')
```
这段代码包括了VGG16模型的定义、Grad-CAM的实现、输入图像的预处理以及结果图像的保存。你只需将`input.jpg`替换为你自己的输入图像,运行代码即可得到Grad-CAM可视化结果图像`output.jpg`。
阅读全文