U-Net神经网络代码
时间: 2024-05-26 07:06:58 浏览: 101
U-Net是一种常用于图像分割任务的神经网络模型,它具有编码器-解码器结构和跳跃连接,能够有效地提取图像特征并保留细节信息。下面是一个简单的U-Net神经网络的代码示例:
```python
import torch
import torch.nn as nn
# 定义U-Net模型
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
# 编码器部分
self.encoder = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
# 解码器部分
self.decoder = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2)
)
def forward(self, x):
# 编码器部分
x1 = self.encoder(x)
# 解码器部分
x2 = self.decoder(x1)
return x2
# 创建U-Net模型实例
model = UNet()
# 打印模型结构
print(model)
```
这段代码定义了一个简单的U-Net模型,包括编码器和解码器部分。编码器使用卷积层和池化层来提取图像特征,解码器使用卷积层和反卷积层来进行特征重建。在`forward`方法中,输入数据经过编码器和解码器后得到输出。
阅读全文