PyTorch 的 UNet 模型特征图可视化代码
时间: 2023-09-23 21:07:34 浏览: 444
PyTorch 模型可视化的例子
下面是 PyTorch 的 UNet 模型特征图可视化代码：
```python
import torch
from torch.autograd import Variable
import numpy as np
import cv2
import matplotlib.pyplot as plt
def hook_fn(m, i, o):
    """Debug hook: print a module and the two values delivered with it.

    The three-argument signature matches what PyTorch passes to module
    hooks. The banners label the second and third arguments as gradients,
    so this is presumably meant for a backward hook -- TODO confirm; it is
    not registered anywhere in the visible code.

    Args:
        m: The module the hook fired on.
        i: Value printed under the "Input Grad" banner.
        o: Value printed under the "Output Grad" banner.
    """
    print(m)
    print("------------Input Grad------------")
    print(i)
    print("------------Output Grad------------")
    print(o)
class Unet(torch.nn.Module):
    """U-Net assembled from DoubleConv, Down, Up and OutConv stages.

    The encoder runs ``inc`` followed by four ``Down`` stages; the decoder
    runs four ``Up`` stages, each receiving the matching encoder output as
    a skip connection, and ``outc`` maps the result to ``n_classes``
    channels.

    Args:
        n_channels: Number of channels in the input image tensor.
        n_classes: Number of channels in the returned logits.
        bilinear: Forwarded to every ``Up`` stage. When true, ``factor`` is
            2 and the bottleneck / decoder widths are halved so that the
            concatenated skip + upsampled tensor has the channel count each
            ``Up`` stage was built for.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder. Attribute names are part of the checkpoint's state_dict
        # keys, so they must not be renamed.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        # Decoder.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the ``n_classes``-channel logits for input batch ``x``."""
        # Keep every encoder output: each one is reused as a skip input.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Walk back up, pairing each stage with the encoder output of the
        # same depth.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
class DoubleConv(torch.nn.Module):
    """Two consecutive (3x3 Conv2d -> BatchNorm2d -> ReLU) layers.

    Both convolutions use ``padding=1``, so height and width are preserved;
    only the channel count changes, from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Channels of the input tensor.
        out_ch: Channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # The Sequential is stored as ``conv`` and its layer order is fixed:
        # both feed into state_dict key names ("conv.0.weight", ...).
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, 3, padding=1),
            torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, 3, padding=1),
            torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the conv stack to ``x`` and return the result."""
        x = self.conv(x)
        return x
class Down(torch.nn.Module):
    """Encoder stage: 2x2 max-pool followed by a DoubleConv.

    Args:
        in_ch: Channels of the input tensor.
        out_ch: Channels produced by the DoubleConv.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Stored as ``mpconv``; the name is part of the state_dict keys.
        self.mpconv = torch.nn.Sequential(
            torch.nn.MaxPool2d(2),
            DoubleConv(in_ch, out_ch),
        )

    def forward(self, x):
        """Pool ``x`` by a factor of 2, then run it through the DoubleConv."""
        x = self.mpconv(x)
        return x
class Up(torch.nn.Module):
    """Decoder stage: upsample, align to the skip tensor, concatenate, DoubleConv.

    Args:
        in_ch: Channel count of the concatenated (skip + upsampled) tensor
            that is fed to the DoubleConv.
        out_ch: Channels produced by the DoubleConv.
        bilinear: If true, upsample with a parameter-free bilinear
            ``Upsample`` (channel count unchanged). Otherwise use a learned
            ``ConvTranspose2d``.
    """

    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = torch.nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=True
            )
        else:
            # In the non-bilinear Unet the deeper stage emits in_ch channels
            # (e.g. down4 -> 1024 into Up(1024, ...)). The transposed conv
            # must accept in_ch and halve it so that concatenating with the
            # in_ch // 2 skip channels gives in_ch again. The previous
            # ConvTranspose2d(in_ch // 2, in_ch // 2) rejected that input.
            self.up = torch.nn.ConvTranspose2d(
                in_ch, in_ch // 2, kernel_size=2, stride=2
            )
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1``, merge it with skip tensor ``x2`` and convolve.

        Args:
            x1: Output of the deeper stage (lower spatial resolution).
            x2: Skip-connection tensor from the encoder.

        Returns:
            The DoubleConv output for ``cat([x2, x1_upsampled], dim=1)``.
        """
        x1 = self.up(x1)

        # Odd input sizes make the upsampled map smaller than the skip map;
        # pad the difference, split as evenly as possible between the sides.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = torch.nn.functional.pad(
            x1,
            (diffX // 2, diffX - diffX // 2,
             diffY // 2, diffY - diffY // 2),
        )

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x
class OutConv(torch.nn.Module):
    """Final 1x1 convolution mapping ``in_ch`` channels to ``out_ch``.

    Args:
        in_ch: Channels of the input tensor.
        out_ch: Channels of the output tensor.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Kernel size 1: a per-pixel projection, so height and width are
        # unchanged.
        self.conv = torch.nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        """Project ``x`` through the 1x1 convolution."""
        x = self.conv(x)
        return x
# Load the already-trained UNet model.
model = Unet(n_channels=3, n_classes=1)
# map_location="cpu" lets a checkpoint that was saved on a GPU be loaded on
# a machine without CUDA; everything below runs on the CPU.
model.load_state_dict(torch.load("unet.pth", map_location="cpu"))
# Inference mode: BatchNorm uses its running statistics.
model.eval()

# Image preprocessing.
img = cv2.imread("example.jpg")
if img is None:
    # cv2.imread signals a missing or unreadable file by returning None
    # instead of raising; fail here with a clear message.
    raise FileNotFoundError("cv2.imread could not read 'example.jpg'")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
img = cv2.resize(img, (512, 512))
img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
img = img.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
# Add the batch dimension. The deprecated Variable wrapper is unnecessary:
# plain tensors can be passed to the model directly.
img = torch.from_numpy(img).unsqueeze(0)
# Register a forward hook to capture a feature map.
features_blobs = []


def hook_feature(module, input, output):
    """Forward hook: append the module's output to ``features_blobs``.

    Args:
        module: The module the hook fired on (unused).
        input: The inputs passed to the module (unused).
        output: The module's output tensor; stored as a NumPy array.
    """
    features_blobs.append(output.detach().cpu().numpy())


# Unet defines no ``conv1`` attribute, so hooking it raised AttributeError.
# ``inc`` is the first stage and emits 64 channels, which is what the 8x8
# grid below displays.
hook_handle = model.inc.register_forward_hook(hook_feature)

# Run the model and draw the captured feature map.
with torch.no_grad():  # no gradients are needed for visualization
    output = model(img)
# Detach the hook so later forward passes do not keep growing the list.
hook_handle.remove()

fea = features_blobs[0]
plt.figure(figsize=(10, 10))
plt.subplots_adjust(wspace=0, hspace=0)
# Show at most 64 channels (one per grid cell), and never index past the
# channels that were actually captured.
for idx in range(min(64, fea.shape[1])):
    plt.subplot(8, 8, idx + 1)
    plt.axis('off')
    plt.imshow(fea[0][idx], cmap='jet')
plt.show()
```
希望这段代码可以帮到你。
阅读全文