只保留一层unet 把完整代码复制粘贴过来
时间: 2024-11-22 21:41:55 浏览: 33
以下是简化后的代码,仅保留了一层UNet,并去除了其他部分:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import argparse
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
nn.ReLU(),
nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2),
nn.Sigmoid(),
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
class ColorblindDataset(Dataset):
def __init__(self, image_dir, mode='train', transform=None):
self.image_dir = image_dir
self.mode = mode
self.transform = transform
self.normal_images = glob.glob(image_dir + '/' + mode + '/' + 'origin_image' + '/*')
self.blind_images = glob.glob(image_dir + '/' + mode + '/' + 'recolor_image' + '/' + '*Protanopia*')
self.normal_images.sort()
self.blind_images.sort()
self.image_pair = []
for index, image in enumerate(self.normal_images):
self.image_pair.append([self.normal_images[index], self.blind_images[index]])
def __len__(self):
return len(self.normal_images)
def __getitem__(self, idx):
normal_path, blind_path = self.image_pair[idx]
normal_image = Image.open(normal_path).convert('RGB')
blind_image = Image.open(blind_path).convert('RGB')
if self.transform:
normal_image = self.transform(normal_image)
blind_image = self.transform(blind_image)
return normal_image, blind_image
def train_unet_model(unet_model, dataloader, criterion, optimizer, device):
unet_model.train()
running_loss = 0.0
for inputs, targets in tqdm(dataloader, desc="Training"):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = unet_model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
running_loss += loss.item()
epoch_loss = running_loss / len(dataloader)
return epoch_loss
def validate_unet_model(unet_model, dataloader, criterion, device):
unet_model.eval()
running_loss = 0.0
with torch.no_grad():
for inputs, targets in tqdm(dataloader, desc="Validation"):
inputs, targets = inputs.to(device), targets.to(device)
outputs = unet_model(inputs)
loss = criterion(outputs, targets)
running_loss += loss.item()
epoch_loss = running_loss / len(dataloader)
return epoch_loss
def visualize_results(model, dataloader, device, num_images=10, save_path='./results'):
model.eval()
inputs, _ = next(iter(dataloader))
inputs = inputs.to(device)
with torch.no_grad():
outputs = model(inputs)
inputs = inputs.cpu().numpy()
outputs = outputs.cpu().numpy()
if not os.path.exists(save_path):
os.makedirs(save_path)
plt.figure(figsize=(20, 10))
for i in range(num_images):
plt.subplot(2, num_images, i + 1)
plt.imshow(inputs[i].transpose(1, 2, 0))
plt.title("Original")
plt.axis('off')
plt.subplot(2, num_images, i + 1 + num_images)
plt.imshow(outputs[i].transpose(1, 2, 0))
plt.title("Corrected")
plt.axis('off')
plt.savefig(f'{save_path}_visualization.png')
plt.show()
def plot_and_save_losses(train_losses, val_losses, epoch, path='./loss_plots'):
if not os.path.exists(path):
os.makedirs(path)
epochs = np.arange(1, epoch + 2)
plt.figure(figsize=(10, 5))
plt.plot(epochs, train_losses, label='Training Loss')
plt.plot(epochs, val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Losses Over Epochs')
plt.legend()
plt.savefig(f'{path}_loss_curve.png')
plt.close()
def main(args):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((256, 256)),
])
train_dataset = ColorblindDataset(args.dataset_dir, mode='train', transform=transform)
val_dataset = ColorblindDataset(args.dataset_dir, mode='val', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet_model = UNet().to(device)
if os.path.exists(args.model_weight_path):
print(f"Loading weights from {args.model_weight_path}")
unet_model.load_state_dict(torch.load(args.model_weight_path))
criterion = nn.MSELoss()
optimizer = optim.Adam(unet_model.parameters(), lr=args.learning_rate)
train_losses = []
val_losses = []
for epoch in range(args.num_epochs):
train_loss = train_unet_model(unet_model, train_loader, criterion, optimizer, device)
val_loss = validate_unet_model(unet_model, val_loader, criterion, device)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f'Epoch {epoch + 1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
plot_and_save_losses(train_losses, val_losses, epoch)
visualize_results(unet_model, train_loader, device, save_path='./results/train')
visualize_results(unet_model, val_loader, device, save_path='./results/val')
torch.save(unet_model.state_dict(), args.model_save_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="UNet model")
parser.add_argument('--dataset_dir', type=str, default='./dataset', help='数据集目录路径')
parser.add_argument('--batch_size', type=int, default=10, help='训练批次大小')
parser.add_argument('--learning_rate', type=float, default=0.01, help='优化器学习率')
parser.add_argument('--num_epochs', type=int, default=300, help='训练的轮数')
parser.add_argument('--model_save_path', type=str, default='./model_weights/unet_model.pth', help='保存训练好的模型的路径')
parser.add_argument('--model_weight_path', type=str, default='', help='训练好的模型的路径')
args = parser.parse_args()
main(args)
```
### 主要修改点:
1. **删除了 `AttentionBlock` 和 `AttentionUNet` 类**,因为这些类在单层UNet中不需要。
2. **删除了 `CorrectionUNet` 类**,因为它是基于 `AttentionUNet` 的复杂版本。
3. **修改了 `main` 函数中的模型加载和保存路径**,以适应新的模型结构。
4. **删除了 `train_correction_model` 和 `validate_correction_model` 函数**,因为它们是针对 `CorrectionUNet` 的特定实现。
5. **添加了 `train_unet_model` 和 `validate_unet_model` 函数**,用于训练和验证单层UNet模型。
6. **更新了 `visualize_results` 函数**,使其适用于单层UNet模型的输出。
这样,代码就简化为一个基本的单层UNet模型,可以用于训练和验证。
阅读全文