Draw a UNet neural network diagram for me, using code
You can draw a UNet network diagram with the `torchviz` library, which renders the model's computation graph. The example below shows how to use `torchviz` to plot the computation graphs of a `UNet` and an `AttentionUNet` model.
First, make sure the `torchviz` library is installed. If it is not, install it with:
```bash
pip install torchviz
```
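Note that `torchviz` renders graphs through Graphviz, so the Graphviz system package (the `dot` executable that `make_dot(...).render()` invokes) must also be installed, e.g. `apt-get install graphviz` on Debian/Ubuntu or `brew install graphviz` on macOS.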
Then, use the following code to draw the computation graphs of the `UNet` and `AttentionUNet` models:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchviz import make_dot
# Define the UNet model
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
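# Note: this minimal UNet is a plain encoder-decoder and omits the skip
# connections of the original U-Net; the AttentionUNet below restores them
# (via torch.cat) and adds attention gates on each skip path.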
# Define the AttentionBlock module
class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # Gate the skip-connection features x with an attention map
        # computed from the gating signal g and x themselves
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi
# Define the AttentionUNet model
class AttentionUNet(nn.Module):
    def __init__(self):
        super(AttentionUNet, self).__init__()
        self.encoder1 = self.conv_block(3, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.encoder4 = self.conv_block(256, 512)
        self.bottleneck = self.conv_block(512, 1024)
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.att4 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.decoder4 = self.conv_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.decoder3 = self.conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.decoder2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.decoder1 = self.conv_block(128, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Encoder path with 2x downsampling between stages
        e1 = self.encoder1(x)
        e2 = self.encoder2(F.max_pool2d(e1, 2))
        e3 = self.encoder3(F.max_pool2d(e2, 2))
        e4 = self.encoder4(F.max_pool2d(e3, 2))
        b = self.bottleneck(F.max_pool2d(e4, 2))
        # Decoder path: upsample, gate the skip connection, concatenate, convolve
        d4 = self.upconv4(b)
        e4 = self.att4(g=d4, x=e4)
        d4 = torch.cat((e4, d4), dim=1)
        d4 = self.decoder4(d4)
        d3 = self.upconv3(d4)
        e3 = self.att3(g=d3, x=e3)
        d3 = torch.cat((e3, d3), dim=1)
        d3 = self.decoder3(d3)
        d2 = self.upconv2(d3)
        e2 = self.att2(g=d2, x=e2)
        d2 = torch.cat((e2, d2), dim=1)
        d2 = self.decoder2(d2)
        d1 = self.upconv1(d2)
        e1 = self.att1(g=d1, x=e1)
        d1 = torch.cat((e1, d1), dim=1)
        d1 = self.decoder1(d1)
        out = self.final_conv(d1)
        out = self.sigmoid(out)
        return out
# Create model instances
unet = UNet()
attention_unet = AttentionUNet()

# Create a dummy input tensor
x = torch.randn(1, 3, 256, 256)

# Render the UNet computation graph
y = unet(x)
make_dot(y, params=dict(list(unet.named_parameters()))).render("unet", format="png")

# Render the AttentionUNet computation graph
y = attention_unet(x)
make_dot(y, params=dict(list(attention_unet.named_parameters()))).render("attention_unet", format="png")
```
This code generates two PNG files, `unet.png` and `attention_unet.png`, containing the computation graphs of the `UNet` and `AttentionUNet` models respectively. Inspecting them shows each model's structure and data flow. Keep in mind that `make_dot` traces the autograd graph, so the output is an operation-level trace rather than a clean architecture block diagram, and it can become quite dense for a model the size of `AttentionUNet`.
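If the rendered graph is too dense to read, a layer-by-layer textual summary is a useful complement. A minimal sketch, assuming the optional third-party `torchinfo` package (`pip install torchinfo`) is available:
```python
# Optional: textual layer-by-layer summary (assumes `pip install torchinfo`)
from torchinfo import summary

# Reports each layer's output shape and parameter count for the given input size
summary(unet, input_size=(1, 3, 256, 256))
summary(attention_unet, input_size=(1, 3, 256, 256))
```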
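Another option is to export the models to ONNX and open the files in a viewer such as Netron (https://netron.app), which draws an architecture-style block diagram instead of the raw autograd trace. A minimal sketch, assuming the `onnx` package is installed:
```python
# Optional: export to ONNX and inspect the graph in Netron (assumes `pip install onnx`)
# torch.onnx.export traces each model with the dummy input x defined above
torch.onnx.export(unet, x, "unet.onnx")
torch.onnx.export(attention_unet, x, "attention_unet.onnx")
```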