Keep only one UNet layer and give me the complete code
Here is the simplified `UNet` model, keeping only one encoder stage and one decoder stage:
```python
import torch
import torch.nn as nn


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Encoder: one conv + ReLU stage, followed by 2x downsampling.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Decoder: one stride-2 transposed convolution restores the input
        # resolution; Sigmoid keeps the output in [0, 1].
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
```
This simplified `UNet` has only one encoder stage and one decoder stage, which is suitable for simple image-processing tasks. If you need further adjustments or extensions, let me know the specifics.
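As a quick sanity check, here is a minimal sketch (not part of the original answer) that assumes the `UNet` class above and verifies that the output has the same shape as the input:

```python
# Minimal shape check for the one-layer UNet defined above.
model = UNet()
dummy = torch.randn(1, 3, 256, 256)  # one RGB image, 256x256
out = model(dummy)
# MaxPool halves the resolution to 128x128; the stride-2 transposed
# convolution doubles it back, so input and output shapes match.
print(out.shape)  # torch.Size([1, 3, 256, 256])
```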
Related questions
Keep only one UNet layer and paste the complete code here
Here is the simplified code, keeping only the single UNet and removing the other components:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import argparse
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Encoder: two conv stages, 3 -> 64 -> 128 channels, each followed
        # by 2x downsampling.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Decoder: two stride-2 transposed convolutions restore the input
        # resolution; Sigmoid keeps the output in [0, 1].
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


class ColorblindDataset(Dataset):
    def __init__(self, image_dir, mode='train', transform=None):
        self.image_dir = image_dir
        self.mode = mode
        self.transform = transform
        # Pair each original image with its Protanopia-recolored counterpart.
        # This assumes both folders hold the same number of images and that
        # sorting by filename aligns the pairs.
        self.normal_images = glob.glob(image_dir + '/' + mode + '/' + 'origin_image' + '/*')
        self.blind_images = glob.glob(image_dir + '/' + mode + '/' + 'recolor_image' + '/' + '*Protanopia*')
        self.normal_images.sort()
        self.blind_images.sort()
        self.image_pair = list(zip(self.normal_images, self.blind_images))

    def __len__(self):
        return len(self.image_pair)

    def __getitem__(self, idx):
        normal_path, blind_path = self.image_pair[idx]
        normal_image = Image.open(normal_path).convert('RGB')
        blind_image = Image.open(blind_path).convert('RGB')
        if self.transform:
            normal_image = self.transform(normal_image)
            blind_image = self.transform(blind_image)
        return normal_image, blind_image


def train_unet_model(unet_model, dataloader, criterion, optimizer, device):
    unet_model.train()
    running_loss = 0.0
    for inputs, targets in tqdm(dataloader, desc="Training"):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = unet_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / len(dataloader)
    return epoch_loss


def validate_unet_model(unet_model, dataloader, criterion, device):
    unet_model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = unet_model(inputs)
            loss = criterion(outputs, targets)
            running_loss += loss.item()
    epoch_loss = running_loss / len(dataloader)
    return epoch_loss


def visualize_results(model, dataloader, device, num_images=10, save_path='./results'):
    model.eval()
    inputs, _ = next(iter(dataloader))
    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    inputs = inputs.cpu().numpy()
    outputs = outputs.cpu().numpy()
    # Guard against batches smaller than num_images.
    num_images = min(num_images, inputs.shape[0])
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    plt.figure(figsize=(20, 10))
    for i in range(num_images):
        plt.subplot(2, num_images, i + 1)
        plt.imshow(inputs[i].transpose(1, 2, 0))
        plt.title("Original")
        plt.axis('off')
        plt.subplot(2, num_images, i + 1 + num_images)
        plt.imshow(outputs[i].transpose(1, 2, 0))
        plt.title("Corrected")
        plt.axis('off')
    plt.savefig(f'{save_path}_visualization.png')
    plt.show()


def plot_and_save_losses(train_losses, val_losses, epoch, path='./loss_plots'):
    if not os.path.exists(path):
        os.makedirs(path)
    # epoch is 0-indexed, so the x-axis runs from 1 to epoch + 1.
    epochs = np.arange(1, epoch + 2)
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_losses, label='Training Loss')
    plt.plot(epochs, val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Losses Over Epochs')
    plt.legend()
    plt.savefig(f'{path}_loss_curve.png')
    plt.close()


def main(args):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
    ])
    train_dataset = ColorblindDataset(args.dataset_dir, mode='train', transform=transform)
    val_dataset = ColorblindDataset(args.dataset_dir, mode='val', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    unet_model = UNet().to(device)
    if os.path.exists(args.model_weight_path):
        print(f"Loading weights from {args.model_weight_path}")
        unet_model.load_state_dict(torch.load(args.model_weight_path, map_location=device))

    criterion = nn.MSELoss()
    optimizer = optim.Adam(unet_model.parameters(), lr=args.learning_rate)

    train_losses = []
    val_losses = []
    for epoch in range(args.num_epochs):
        train_loss = train_unet_model(unet_model, train_loader, criterion, optimizer, device)
        val_loss = validate_unet_model(unet_model, val_loader, criterion, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f'Epoch {epoch + 1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
        plot_and_save_losses(train_losses, val_losses, epoch)

    visualize_results(unet_model, train_loader, device, save_path='./results/train')
    visualize_results(unet_model, val_loader, device, save_path='./results/val')
    # Make sure the save directory exists before writing the weights.
    os.makedirs(os.path.dirname(args.model_save_path), exist_ok=True)
    torch.save(unet_model.state_dict(), args.model_save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="UNet model")
    parser.add_argument('--dataset_dir', type=str, default='./dataset', help='Path to the dataset directory')
    parser.add_argument('--batch_size', type=int, default=10, help='Training batch size')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Optimizer learning rate')
    parser.add_argument('--num_epochs', type=int, default=300, help='Number of training epochs')
    parser.add_argument('--model_save_path', type=str, default='./model_weights/unet_model.pth', help='Path for saving the trained model')
    parser.add_argument('--model_weight_path', type=str, default='', help='Path to pretrained model weights')
    args = parser.parse_args()
    main(args)
```
### Key changes:
1. **Removed the `AttentionBlock` and `AttentionUNet` classes**, since they are not needed for a single-layer UNet.
2. **Removed the `CorrectionUNet` class**, the more complex variant built on `AttentionUNet`.
3. **Adjusted the model loading and saving paths in `main`** to match the new model structure.
4. **Removed the `train_correction_model` and `validate_correction_model` functions**, which were specific to `CorrectionUNet`.
5. **Added the `train_unet_model` and `validate_unet_model` functions** for training and validating the single-layer UNet.
6. **Updated the `visualize_results` function** to work with the single-layer UNet's output.

With these changes, the code reduces to a basic single-layer UNet that can be trained and validated.
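For reference, here is a hypothetical smoke test (not part of the original answer) that drives `main` programmatically with an `argparse.Namespace` instead of command-line flags. The dataset path and the one-epoch settings are illustrative assumptions; `./dataset` must follow the `origin_image` / `recolor_image` layout expected by `ColorblindDataset`.

```python
# Hypothetical smoke test: run one short training epoch without the CLI.
from argparse import Namespace

args = Namespace(
    dataset_dir='./dataset',                          # assumed dataset location
    batch_size=2,                                     # small batch for a quick check
    learning_rate=1e-3,
    num_epochs=1,                                     # one epoch only
    model_save_path='./model_weights/unet_model.pth',
    model_weight_path='',                             # no pretrained weights
)
main(args)
```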
Keep only two UNet layers and paste the complete code here
Here is the simplified `UNet` model, keeping two encoder stages and two decoder stages:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.encoder1 = self.conv_block(3, 64)
        self.encoder2 = self.conv_block(64, 128)
        # Upsample the bottleneck back to the input resolution.
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(),
        )
        # After the skip connection (64 + 64 channels), map back to 3 output
        # channels. A plain convolution is used because the feature map is
        # already at the input resolution; a second stride-2 transposed
        # convolution would double the output size.
        self.decoder2 = nn.Sequential(
            nn.Conv2d(128, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        e1 = self.encoder1(x)                    # full resolution, 64 channels
        e2 = self.encoder2(F.max_pool2d(e1, 2))  # half resolution, 128 channels
        d1 = self.decoder1(e2)                   # full resolution, 64 channels
        d1 = torch.cat((e1, d1), dim=1)          # skip connection -> 128 channels
        d2 = self.decoder2(d1)                   # 3-channel output in [0, 1]
        return d2
```
### Explanation
1. **Encoder**:
   - `encoder1`: converts the 3-channel input to 64 channels at full resolution.
   - `encoder2`: after max pooling halves the resolution, raises the channel count from 64 to 128.
2. **Decoder**:
   - `decoder1`: a stride-2 transposed convolution upsamples the 128-channel bottleneck back to the input resolution with 64 channels, followed by ReLU.
   - `decoder2`: once the skip connection brings the channel count to 128, a 3x3 convolution maps to 3 output channels and a Sigmoid constrains the output to [0, 1]. A plain convolution is used here because the feature map is already at the input resolution.
3. **Forward pass**:
   - The input passes through `encoder1`, is max-pooled, and passes through `encoder2`.
   - `decoder1` upsamples the result, which is concatenated with the `encoder1` feature map (the skip connection) to preserve more detail.
   - `decoder2` produces the final 3-channel image at the input resolution, as the sanity check below confirms.
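Here is a quick sanity check, assuming the two-layer `UNet` above (this sketch is not part of the original answer), confirming that the skip connection and output shapes line up:

```python
# Shape check for the two-layer UNet with skip connection.
model = UNet().eval()  # eval mode so BatchNorm uses running statistics
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
# encoder1: (1, 64, 256, 256) -> pooled: (1, 64, 128, 128) -> encoder2: (1, 128, 128, 128)
# decoder1: (1, 64, 256, 256); after the skip concat: (1, 128, 256, 256)
print(out.shape)  # torch.Size([1, 3, 256, 256])
```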
I hope this simplified `UNet` model meets your needs. If you have any further questions or need adjustments, let me know!