一维 CNN + Grad-CAM(PyTorch)
时间: 2023-12-13 21:33:51 浏览: 165
以下是一维CNN+Grad-CAM的PyTorch实现:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torchvision import models
class GradCAM:
    """Grad-CAM for 1-D CNNs that expose ``features`` and ``classifier``.

    The wrapped model must provide ``model.features`` (a container whose
    ``_modules`` keys are integer-like strings, e.g. ``nn.Sequential``) and
    ``model.classifier``, which is applied to the flattened feature output.

    Usage: ``forward(x)`` -> ``backward()`` -> ``generate()``.
    """

    def __init__(self, model, target_layer):
        """Store the model and the index of the layer in ``model.features``
        whose output is used to build the class activation map."""
        self.model = model
        self.target_layer = target_layer
        self.feature_maps = None   # activations of the target layer (detached)
        self.gradient = None       # gradient of the class score w.r.t. them
        self.logits = None         # classifier output of the last forward()
        self.input_length = None   # last dimension of the last input

    def save_feature_maps(self, module, input, output):
        """Forward-hook style callback that records ``output`` (detached)."""
        self.feature_maps = output.detach()

    def save_gradient(self, grad):
        """Tensor-hook callback that records the incoming gradient."""
        self.gradient = grad.detach()

    def forward(self, x):
        """Run ``x`` through the model and capture the target activations.

        Returns the classifier output, which is also kept on the instance so
        that ``backward`` can backpropagate a class score.

        Raises:
            ValueError: if ``target_layer`` matches no module in
                ``model.features``.
        """
        # Drop state from any previous run so stale tensors are never mixed
        # with the new ones.
        self.feature_maps = None
        self.gradient = None
        self.logits = None
        self.input_length = x.size(-1)

        # Gradients are required for the hook even if the caller is inside a
        # ``torch.no_grad()`` block.
        with torch.enable_grad():
            for module_pos, module in self.model.features._modules.items():
                x = module(x)
                if int(module_pos) == self.target_layer:
                    x.register_hook(self.save_gradient)
                    # clone(): a following in-place op (e.g. ReLU(inplace=True))
                    # must not overwrite the stored activations.
                    self.feature_maps = x.detach().clone()
            # Flatten before the classifier; pooling the features to a single
            # value per channel here would not match a classifier sized for
            # the flattened feature map.
            x = x.view(x.size(0), -1)
            x = self.model.classifier(x)

        if self.feature_maps is None:
            raise ValueError(
                "target_layer %r not found in model.features" % (self.target_layer,)
            )
        self.logits = x
        return x

    def backward(self, target_class=None):
        """Backpropagate the score of ``target_class`` to the target layer.

        Args:
            target_class: class index (int) or per-sample indices (sequence or
                tensor). Defaults to the predicted (argmax) class per sample.

        Raises:
            RuntimeError: if ``forward`` has not been called first.
        """
        if self.logits is None:
            raise RuntimeError("forward() must be called before backward()")

        scores = self.logits
        if target_class is None:
            class_idx = scores.argmax(dim=1, keepdim=True)
        else:
            class_idx = torch.as_tensor(
                target_class, dtype=torch.long, device=scores.device
            ).reshape(-1, 1).expand(scores.size(0), 1)

        self.model.zero_grad()
        # Summing per-sample scores keeps the samples' gradients independent.
        # retain_graph=True allows calling backward() again for other classes.
        scores.gather(1, class_idx).sum().backward(retain_graph=True)

    def generate(self):
        """Build the saliency map from the captured activations and gradient.

        Returns a tensor of shape ``(batch, 1, input_length)``. Each sample is
        min-max normalised on its own; a sample whose map is constant is
        returned unscaled.

        Raises:
            RuntimeError: if ``forward`` and ``backward`` have not both run.
        """
        if self.feature_maps is None or self.gradient is None:
            raise RuntimeError(
                "forward() and backward() must be called before generate()"
            )

        # Channel weights: gradient averaged over the length dimension.
        weights = F.adaptive_avg_pool1d(self.gradient, 1)
        saliency_map = (weights * self.feature_maps).sum(dim=1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        # Upsample to the length of the signal that was passed to forward().
        saliency_map = F.interpolate(
            saliency_map,
            size=(self.input_length,),
            mode='linear',
            align_corners=False,
        )

        saliency_min = saliency_map.min(dim=2, keepdim=True)[0]
        saliency_max = saliency_map.max(dim=2, keepdim=True)[0]
        span = saliency_max - saliency_min
        has_range = span > 0
        # Substitute 1 for a zero span so the division never produces NaN.
        safe_span = torch.where(has_range, span, torch.ones_like(span))
        normalised = (saliency_map - saliency_min) / safe_span
        return torch.where(has_range, normalised, saliency_map)
class Net(nn.Module):
    """Small 1-D CNN: two conv/ReLU/max-pool stages and a two-layer classifier.

    Expects input of shape ``(batch, 1, input_length)`` and returns two
    logits per sample.
    """

    def __init__(self, input_length=100):
        """Build the network.

        Args:
            input_length: length of the input signal; it determines the width
                of the first linear layer. The default of 100 gives
                ``20 * 22`` flattened features.

        Raises:
            ValueError: if ``input_length`` is too short to leave at least one
                time step after both conv/pool stages.
        """
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(1, 10, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(10, 20, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2),
        )
        # Each stage: an unpadded kernel-5 conv removes 4 steps, then the
        # kernel-2 max-pool halves the length (floor division).
        pooled_length = ((input_length - 4) // 2 - 4) // 2
        if pooled_length < 1:
            raise ValueError(
                "input_length=%r is too short for this network" % (input_length,)
            )
        self.classifier = nn.Sequential(
            nn.Linear(20 * pooled_length, 50),
            nn.ReLU(inplace=True),
            nn.Linear(50, 2),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, 2)`` for input ``x``."""
        x = self.features(x)
        # Flatten (batch, channels, length) to (batch, channels * length).
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
model = Net()
grad_cam = GradCAM(model, target_layer=4)

# Train the model and save its weights to 'model.pth' (omitted here).
...

# Load the trained weights and run the Grad-CAM visualisation.
# map_location='cpu' lets a checkpoint saved on a GPU load into this CPU model.
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
model.eval()

# Named `signal` rather than `input` to avoid shadowing the builtin.
signal = torch.randn(1, 1, 100)

# Plain prediction pass; no autograd graph is needed for it.
with torch.no_grad():
    output = model(signal)
pred = output.argmax(dim=1)

grad_cam.forward(signal)
grad_cam.backward()
saliency_map = grad_cam.generate()

# Visualisation: plot the signal and its saliency map on the same axes.
import matplotlib.pyplot as plt

plt.plot(signal[0, 0].numpy())
# detach() so the conversion works even if the map is still tied to autograd.
plt.plot(saliency_map[0, 0].detach().numpy())
plt.show()
```
阅读全文