写一个unet网络的py文件,包含:2个类,2个方法,使用nn.module 和 nn.Sequential 和for 循环
时间: 2024-04-07 08:30:23 浏览: 101
下面是一个示例的PyTorch Unet网络的Python文件示例,包含两个类和两个方法,使用`nn.Module`和`nn.Sequential`以及for循环:
```python
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNet, self).__init__()
self.down1 = DoubleConv(in_channels, 64)
self.down2 = DoubleConv(64, 128)
self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.up2 = nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2)
self.out_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)
def forward(self, x):
x1 = self.down1(x)
x2 = self.down2(x1)
x = self.up1(x2)
x = self.up2(x + x1)
x = self.out_conv(x)
return x
# 测试代码
if __name__ == "__main__":
# 创建一个输入张量
input_tensor = torch.randn((1, 3, 256, 256))
# 创建UNet模型实例
model = UNet(in_channels=3, out_channels=3)
# 将输入张量传递给模型进行前向传播
output_tensor = model(input_tensor)
# 打印输出张量的形状
print("Output tensor shape:", output_tensor.shape)
```
在这个示例中,有两个类:`DoubleConv`和`UNet`。`DoubleConv`类定义了一个双重卷积层,`UNet`类定义了一个包含两个下采样(编码器)和两个上采样(解码器)的UNet网络。这些类都继承自`nn.Module`,并通过`nn.Sequential`定义了模型的层次结构。在`UNet`类的前向传播方法中,通过for循环实现了层次之间的连接。
注意:这只是一个简化的示例,用于说明UNet网络的基本结构。你可以根据需要修改和扩展这些类和方法。
阅读全文