pytorch 的unet模型特征图可视化代码

时间: 2023-09-23 16:07:34 浏览: 243
很好,下面是 PyTorch 的 UNet 模型特征图可视化代码: ```python import torch from torch.autograd import Variable import numpy as np import cv2 import matplotlib.pyplot as plt def hook_fn(m, i, o): print(m) print("------------Input Grad------------") print(i) print("------------Output Grad------------") print(o) class Unet(torch.nn.Module): def __init__(self, n_channels, n_classes, bilinear=True): super(Unet, self).__init__() self.n_channels = n_channels self.n_classes = n_classes self.bilinear = bilinear self.inc = DoubleConv(n_channels, 64) self.down1 = Down(64, 128) self.down2 = Down(128, 256) self.down3 = Down(256, 512) factor = 2 if bilinear else 1 self.down4 = Down(512, 1024 // factor) self.up1 = Up(1024, 512 // factor, bilinear) self.up2 = Up(512, 256 // factor, bilinear) self.up3 = Up(256, 128 // factor, bilinear) self.up4 = Up(128, 64, bilinear) self.outc = OutConv(64, n_classes) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) logits = self.outc(x) return logits class DoubleConv(torch.nn.Module): def __init__(self, in_ch, out_ch): super(DoubleConv, self).__init__() self.conv = torch.nn.Sequential( torch.nn.Conv2d(in_ch, out_ch, 3, padding=1), torch.nn.BatchNorm2d(out_ch), torch.nn.ReLU(inplace=True), torch.nn.Conv2d(out_ch, out_ch, 3, padding=1), torch.nn.BatchNorm2d(out_ch), torch.nn.ReLU(inplace=True) ) def forward(self, x): x = self.conv(x) return x class Down(torch.nn.Module): def __init__(self, in_ch, out_ch): super(Down, self).__init__() self.mpconv = torch.nn.Sequential( torch.nn.MaxPool2d(2), DoubleConv(in_ch, out_ch) ) def forward(self, x): x = self.mpconv(x) return x class Up(torch.nn.Module): def __init__(self, in_ch, out_ch, bilinear=True): super(Up, self).__init__() if bilinear: self.up = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) else: self.up = torch.nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) self.conv = DoubleConv(in_ch, out_ch) def forward(self, x1, x2): x1 = self.up(x1) diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = torch.nn.functional.pad(x1, (diffX // 2, diffX - diffX//2, diffY // 2, diffY - diffY//2)) x = torch.cat([x2, x1], dim=1) x = self.conv(x) return x class OutConv(torch.nn.Module): def __init__(self, in_ch, out_ch): super(OutConv, self).__init__() self.conv = torch.nn.Conv2d(in_ch, out_ch, 1) def forward(self, x): x = self.conv(x) return x # 加载已经训练好的UNet模型 model = Unet(n_channels = 3, n_classes = 1) model.load_state_dict(torch.load("unet.pth")) model.eval() # 图像预处理 img = cv2.imread("example.jpg") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = cv2.resize(img, (512, 512)) img = np.transpose(img, (2, 0, 1)) img = img.astype(np.float32) / 255. img = Variable(torch.from_numpy(img).unsqueeze(0)) # 注册钩子,获取特征图 features_blobs = [] def hook_feature(module, input, output): features_blobs.append(output.data.cpu().numpy()) model.conv1.register_forward_hook(hook_feature) # 获取并绘制特征图 output = model(img) fea = features_blobs[0] plt.figure(figsize=(10, 10)) plt.subplots_adjust(wspace=0, hspace=0) for idx in range(64): plt.subplot(8, 8, idx + 1) plt.axis('off') plt.imshow(fea[0][idx], cmap='jet') plt.show() ``` 希望这个代码可以帮到你。

相关推荐

最新推荐

recommend-type

pre_o_1csdn63m9a1bs0e1rr51niuu33e.a

pre_o_1csdn63m9a1bs0e1rr51niuu33e.a
recommend-type

matlab建立计算力学课程的笔记和文件.zip

matlab建立计算力学课程的笔记和文件.zip
recommend-type

FT-Prog-v3.12.38.643-FTD USB 工作模式设定及eprom读写

FT_Prog_v3.12.38.643--FTD USB 工作模式设定及eprom读写
recommend-type

matlab基于RRT和人工势场法混合算法的路径规划.zip

matlab基于RRT和人工势场法混合算法的路径规划.zip
recommend-type

matlab基于matlab的两步定位软件定义接收机的开源GNSS直接位置估计插件模块.zip

matlab基于matlab的两步定位软件定义接收机的开源GNSS直接位置估计插件模块.zip
recommend-type

zigbee-cluster-library-specification

最新的zigbee-cluster-library-specification说明文档。
recommend-type

管理建模和仿真的文件

管理Boualem Benatallah引用此版本:布阿利姆·贝纳塔拉。管理建模和仿真。约瑟夫-傅立叶大学-格勒诺布尔第一大学,1996年。法语。NNT:电话:00345357HAL ID:电话:00345357https://theses.hal.science/tel-003453572008年12月9日提交HAL是一个多学科的开放存取档案馆,用于存放和传播科学研究论文,无论它们是否被公开。论文可以来自法国或国外的教学和研究机构,也可以来自公共或私人研究中心。L’archive ouverte pluridisciplinaire
recommend-type

实现实时数据湖架构:Kafka与Hive集成

![实现实时数据湖架构:Kafka与Hive集成](https://img-blog.csdnimg.cn/img_convert/10eb2e6972b3b6086286fc64c0b3ee41.jpeg) # 1. 实时数据湖架构概述** 实时数据湖是一种现代数据管理架构,它允许企业以低延迟的方式收集、存储和处理大量数据。与传统数据仓库不同,实时数据湖不依赖于预先定义的模式,而是采用灵活的架构,可以处理各种数据类型和格式。这种架构为企业提供了以下优势: - **实时洞察:**实时数据湖允许企业访问最新的数据,从而做出更明智的决策。 - **数据民主化:**实时数据湖使各种利益相关者都可
recommend-type

2. 通过python绘制y=e-xsin(2πx)图像

可以使用matplotlib库来绘制这个函数的图像。以下是一段示例代码: ```python import numpy as np import matplotlib.pyplot as plt def func(x): return np.exp(-x) * np.sin(2 * np.pi * x) x = np.linspace(0, 5, 500) y = func(x) plt.plot(x, y) plt.xlabel('x') plt.ylabel('y') plt.title('y = e^{-x} sin(2πx)') plt.show() ``` 运行这段
recommend-type

JSBSim Reference Manual

JSBSim参考手册,其中包含JSBSim简介,JSBSim配置文件xml的编写语法,编程手册以及一些应用实例等。其中有部分内容还没有写完,估计有生之年很难看到完整版了,但是内容还是很有参考价值的。