只保留两层unet 把完整代码复制粘贴过来
时间: 2024-11-23 12:49:22 浏览: 9
以下是简化后的 `UNet` 模型,仅保留了两层编码器和解码器:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
self.encoder1 = self.conv_block(3, 64)
self.encoder2 = self.conv_block(64, 128)
self.decoder1 = nn.Sequential(
nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
nn.ReLU(),
)
self.decoder2 = nn.Sequential(
nn.ConvTranspose2d(128, 3, kernel_size=2, stride=2),
nn.Sigmoid(),
)
def conv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
e1 = self.encoder1(x)
e2 = self.encoder2(F.max_pool2d(e1, 2))
d1 = self.decoder1(e2)
d1 = torch.cat((e1, d1), dim=1)
d2 = self.decoder2(d1)
return d2
```
### 解释
1. **编码器部分**:
- `encoder1`: 将输入图像从3个通道转换为64个通道。
- `encoder2`: 将64个通道的特征图通过最大池化操作降采样,并将其通道数增加到128。
2. **解码器部分**:
- `decoder1`: 使用转置卷积(上采样)将128个通道的特征图恢复到64个通道,并使用ReLU激活函数。
- `decoder2`: 再次使用转置卷积将64个通道的特征图恢复到3个通道,并使用Sigmoid激活函数将输出限制在0到1之间。
3. **前向传播**:
- 输入图像首先通过两个编码器层进行下采样。
- 然后通过两个解码器层进行上采样。
- 最后,将解码器的输出与相应的编码器特征图拼接在一起,以保留更多的细节信息。
希望这个简化版本的 `UNet` 模型能满足你的需求。如果有任何进一步的问题或需要调整的地方,请告诉我!
阅读全文