PyTorch 图片分类：使用 CutMix 进行数据增强，并绘制 Grad-CAM 热力图的完整代码
时间: 2023-06-06 16:07:44 浏览: 125
好的，下面是基于 PyTorch、使用 CutMix 进行数据增强并绘制 Grad-CAM 热力图的完整代码：
```python
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms
class CutMixTransform:
    """CutMix augmentation for a pair of PIL images.

    Reference: https://arxiv.org/pdf/1905.04899.pdf

    A random rectangle taken from ``mix_image`` is pasted into ``image`` and
    the two labels are mixed in proportion to the area each image contributes.
    """

    def __init__(self, alpha=1.0, prob=0.5):
        """
        Args:
            alpha: concentration of the Beta(alpha, alpha) distribution that
                the mixing ratio ``lam`` is sampled from; must be > 0.
                ``alpha=1.0`` samples ``lam`` uniformly from [0, 1].
            prob: probability that a call actually applies CutMix.

        Raises:
            ValueError: if ``alpha`` is not positive.
        """
        if alpha <= 0:
            raise ValueError(f"alpha must be > 0, got {alpha!r}")
        self.alpha = alpha
        self.prob = prob

    @staticmethod
    def _rand_bbox(w, h, lam):
        """Return a random box ``(x1, y1, x2, y2)`` inside a ``w`` x ``h`` image.

        Before clipping the box covers approximately ``1 - lam`` of the image
        area; clipping to the image border can make it smaller.
        """
        cut_rat = np.sqrt(1.0 - lam)
        # Builtin int: np.int was deprecated in NumPy 1.20 and removed in 1.24.
        cut_w = int(w * cut_rat)
        cut_h = int(h * cut_rat)
        cx = np.random.randint(w)
        cy = np.random.randint(h)
        x1 = int(np.clip(cx - cut_w // 2, 0, w))
        y1 = int(np.clip(cy - cut_h // 2, 0, h))
        x2 = int(np.clip(cx + cut_w // 2, 0, w))
        y2 = int(np.clip(cy + cut_h // 2, 0, h))
        return x1, y1, x2, y2

    def __call__(self, image, target, mix_image=None, mix_target=None):
        """Apply CutMix to ``image`` with probability ``prob``.

        Args:
            image: PIL image to augment.
            target: label of ``image``.
            mix_image: PIL image the patch is cut from. CutMix needs two
                images, so when this is None the inputs are returned unchanged.
            mix_target: label of ``mix_image``.

        Returns:
            ``(image, target)`` unchanged when CutMix is skipped; otherwise
            ``(mixed_image, (target, mix_target, lam))`` where ``lam`` is the
            fraction of ``mixed_image`` that still comes from ``image``
            (train with ``lam * loss(target) + (1 - lam) * loss(mix_target)``).
        """
        if mix_image is None or np.random.rand() >= self.prob:
            return image, target

        w, h = image.size
        lam = np.random.beta(self.alpha, self.alpha)
        x1, y1, x2, y2 = self._rand_bbox(w, h, lam)

        mixed = image.copy()
        if x2 > x1 and y2 > y1:
            # Bring the partner to the same mode and size so the box addresses
            # the same region in both images.
            partner = mix_image.convert(image.mode).resize((w, h))
            mixed.paste(partner.crop((x1, y1, x2, y2)), (x1, y1))

        # Recompute lam from the clipped box so the label weights match the
        # pixels that were actually replaced.
        lam = 1.0 - ((x2 - x1) * (y2 - y1)) / float(w * h)
        return mixed, (target, mix_target, lam)
class Model(nn.Module):
    """ResNet-50 backbone with a new classification head.

    The torchvision ResNet-50 is built (optionally with ImageNet weights) and
    its final fully-connected layer is replaced by a freshly initialised
    ``nn.Linear`` producing ``num_classes`` logits. That new head is
    untrained, so the model must be fine-tuned before its predictions mean
    anything.
    """

    def __init__(self, num_classes, pretrained=True):
        """
        Args:
            num_classes: number of output logits.
            pretrained: load ImageNet weights for the backbone (torchvision
                downloads them on first use). Pass False to build offline.
        """
        super().__init__()
        if hasattr(models, "ResNet50_Weights"):
            # torchvision >= 0.13 deprecates ``pretrained=``; IMAGENET1K_V1 is
            # exactly the checkpoint ``pretrained=True`` used to load.
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            self.resnet = models.resnet50(weights=weights)
        else:
            self.resnet = models.resnet50(pretrained=pretrained)
        # Read the feature width from the existing head instead of hard-coding 2048.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for an image batch ``x``.

        Delegates to the torchvision ResNet forward pass (stem, layer1-4,
        global average pooling, flatten, fc), which is the same sequence of
        stages previously spelled out by hand.
        """
        return self.resnet(x)
def inference(model, image_path, transform=None):
    """Classify a single image.

    Args:
        model: classifier that maps a (1, 3, H, W) batch to logits of shape
            (1, num_classes).
        image_path: path to an image file, or an already-loaded PIL image
            (e.g. the output of CutMixTransform).
        transform: optional preprocessing callable; defaults to the standard
            ImageNet evaluation pipeline (resize 256, centre-crop 224,
            ImageNet mean/std normalisation).

    Returns:
        ``(input_tensor, output_probs, output_label)``: the preprocessed batch
        on the model's device, the softmax probabilities, and the argmax
        class-index tensor of shape (1,).
    """
    if isinstance(image_path, Image.Image):
        image = image_path.convert("RGB")
    else:
        # The context manager closes the file; convert() loads the pixels and
        # guarantees 3 channels for the 3-channel Normalize below.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

    if transform is None:
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Evaluate in eval mode (BatchNorm must not update running stats on a
    # single test image), then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output_tensor = model(input_tensor)
    finally:
        model.train(was_training)

    output_probs = F.softmax(output_tensor, dim=1)
    output_label = torch.argmax(output_probs, dim=1)
    return input_tensor, output_probs, output_label
def gradcam(model, input_tensor, class_idx, target_layer=None):
    """Compute a Grad-CAM heatmap for one class of the first image in a batch.

    Grad-CAM weights each channel of a layer's activations by the spatial mean
    of the gradient of the class score with respect to those activations,
    sums over channels and applies ReLU.

    Args:
        model: classifier returning logits of shape (N, num_classes).
        input_tensor: preprocessed batch of shape (N, C, H, W); only the
            first sample's score is explained.
        class_idx: class to explain; anything accepted by ``int()``
            (Python int, NumPy integer, one-element tensor).
        target_layer: module whose output activations are used. Defaults to
            ``model.resnet.layer4``.

    Returns:
        Detached tensor of shape (h, w) at the target layer's spatial
        resolution, scaled to [0, 1] (all zeros when no location has
        positive evidence).
    """
    if target_layer is None:
        target_layer = model.resnet.layer4
    class_idx = int(class_idx)

    captured = {}

    def _capture(_module, _inputs, output):
        # Keep the activation tensor (still attached to the autograd graph).
        captured["activations"] = output

    param = next(model.parameters(), None)
    device = param.device if param is not None else input_tensor.device
    # requires_grad on the input guarantees a graph through the target layer
    # even when the model's weights are frozen.
    x = input_tensor.detach().to(device).requires_grad_(True)

    was_training = model.training
    model.eval()
    handle = target_layer.register_forward_hook(_capture)
    try:
        with torch.enable_grad():
            logits = model(x)
            activations = captured["activations"]
            score = logits[0, class_idx]
            # Gradient w.r.t. the activations (not the conv weights).
            # autograd.grad does not accumulate into parameter .grad, so
            # repeated calls do not interfere with each other or with training.
            (grads,) = torch.autograd.grad(score, activations)
    finally:
        handle.remove()
        model.train(was_training)

    channel_weights = grads[0].mean(dim=(1, 2), keepdim=True)       # (C, 1, 1)
    cam = F.relu((channel_weights * activations[0]).sum(dim=0))     # (h, w)
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    return cam.detach()
# Label names for the 10-way head, in class-index order.
CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]
# ImageNet statistics used by the Normalize step, shaped to broadcast over (3, H, W).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def _to_displayable(batch):
    """Undo ImageNet normalisation of ``batch[0]``; return an (H, W, 3) array in [0, 1]."""
    image = batch[0].detach().cpu() * IMAGENET_STD + IMAGENET_MEAN
    return image.clamp(0, 1).permute(1, 2, 0).numpy()


# Use the GPU when one is present instead of hard-coding .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model. Its 10-way fc layer is a new, untrained nn.Linear, so the
# predictions below are only meaningful after fine-tuning.
model = Model(num_classes=10).to(device)

# Load image and perform inference
image_path = "cat.jpg"
input_tensor, output_probs, output_label = inference(model, image_path)
# Explain the predicted class, so the "Pred" title below is truthful.
class_idx = int(output_label[0])
class_name = CLASS_NAMES[class_idx]

# CutMix pastes a patch from a second image, so load a partner image and label.
mix_path = "dog.jpg"
mix_class_idx = CLASS_NAMES.index("dog")
cutmix_transforms = CutMixTransform(alpha=1.0, prob=1.0)
with Image.open(image_path) as img, Image.open(mix_path) as mix_img:
    image_cutmix, target_cutmix = cutmix_transforms(
        img.convert("RGB"),
        class_idx,
        mix_image=mix_img.convert("RGB"),
        mix_target=mix_class_idx,
    )
input_tensor_cutmix, output_probs_cutmix, output_label_cutmix = inference(model, image_cutmix)
target_a, target_b, lam = target_cutmix

# Calculate GradCAM heatmap for the predicted class on the CutMix image
heatmap = gradcam(model, input_tensor_cutmix, class_idx)
# Upsample the coarse feature-map-resolution heatmap to the input size so it
# can be overlaid on the image.
heatmap = F.interpolate(
    heatmap.reshape(1, 1, *heatmap.shape[-2:]).float(),
    size=tuple(input_tensor_cutmix.shape[-2:]),
    mode="bilinear",
    align_corners=False,
)[0, 0].detach().cpu().numpy()

# Plot original image, CutMix image and GradCAM overlay
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(_to_displayable(input_tensor))
ax[0].set_title(f"True: cat\nPred: {class_name}")
ax[1].imshow(_to_displayable(input_tensor_cutmix))
ax[1].set_title(f"CutMix: {CLASS_NAMES[target_a]} + {CLASS_NAMES[target_b]}\nlam={lam:.2f}")
ax[2].imshow(_to_displayable(input_tensor_cutmix))
ax[2].imshow(heatmap, cmap="jet", alpha=0.5)
ax[2].set_title(f"GradCAM for {class_name}")
for axis in ax:
    axis.axis("off")
plt.tight_layout()
plt.show()
```
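在这个示例中，我们首先定义了一个 `CutMixTransform` 类来实现 CutMix 数据增强（CutMix 从另一张图像中剪切一块矩形区域粘贴到当前图像上，并按两张图像所占面积的比例混合标签），然后定义了一个以预训练 ResNet50 为骨干网络的图像分类模型。注意：模型最后的全连接层 `fc` 被替换为新初始化的 `nn.Linear`，需要先在目标数据集上微调，预测结果才有意义。接着，我们对一张测试图像进行预测，再用 `CutMixTransform` 对图像进行数据增强，并在增强后的图像上进行预测。最后，我们在增强后的图像上计算所选类别的 Grad-CAM 热力图，并与原始图像一起展示。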
阅读全文