看我的文件 在源代码基础上改成一层unet 改完后把完整的代码复制粘贴给我
时间: 2024-11-23 12:49:53 浏览: 20
好的,我将修改你的 `UNet` 模型,使其只有一层编码器和解码器,并保留其他部分不变。以下是修改后的完整代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import argparse
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2),
nn.Sigmoid()
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
class AttentionBlock(nn.Module):
def __init__(self, F_g, F_l, F_int):
super(AttentionBlock, self).__init__()
self.W_g = nn.Sequential(
nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)
self.W_x = nn.Sequential(
nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)
self.psi = nn.Sequential(
nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(1),
nn.Sigmoid()
)
self.relu = nn.ReLU(inplace=True)
def forward(self, g, x):
g1 = self.W_g(g)
x1 = self.W_x(x)
psi = self.relu(g1 + x1)
psi = self.psi(psi)
return x * psi
class AttentionUNet(nn.Module):
def __init__(self):
super(AttentionUNet, self).__init__()
self.encoder1 = self.conv_block(3, 64)
self.bottleneck = self.conv_block(64, 128)
self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)
self.decoder1 = self.conv_block(128, 64)
self.final_conv = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)
self.sigmoid = nn.Sigmoid()
def conv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
# Encoding
e1 = self.encoder1(x)
b = self.bottleneck(F.max_pool2d(e1, 2))
# Decoding + Attention Gate
d1 = self.upconv1(b)
e1 = self.att1(g=d1, x=e1)
d1 = torch.cat((e1, d1), dim=1)
d1 = self.decoder1(d1)
out = self.final_conv(d1)
out = self.sigmoid(out)
return out
class ColorblindDataset(Dataset):
def __init__(self, image_dir, mode='train', transform=None):
self.image_dir = image_dir
self.mode = mode
self.transform = transform
self.normal_images = glob.glob(image_dir + '/' + mode + '/' + 'origin_image' + '/*')
self.recolor_images = glob.glob(image_dir + '/' + mode + '/' + 'recolor_image' + '/' + '*Protanopia*')
self.correct_images = glob.glob(image_dir + '/' + mode + '/' + 'correct_image' + '/*')
self.normal_images.sort()
self.recolor_images.sort()
self.correct_images.sort()
self.image_pair = []
for index, image in enumerate(self.normal_images):
self.image_pair.append([self.normal_images[index], self.recolor_images[index]])
self.image_pair.append([self.correct_images[index], self.normal_images[index]])
def __len__(self):
return len(self.normal_images)
def __getitem__(self, idx):
normal_path, recolor_path = self.image_pair[idx]
normal_image = Image.open(normal_path).convert('RGB')
recolor_image = Image.open(recolor_path).convert('RGB')
if self.transform:
normal_image = self.transform(normal_image)
recolor_image = self.transform(recolor_image)
return normal_image, recolor_image
def train_one_epoch(model, dataloader, criterion, optimizer, device):
model.train()
running_loss = 0.0
for inputs, targets in tqdm(dataloader, desc="Training"):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
running_loss += loss.item()
epoch_loss = running_loss / len(dataloader)
return epoch_loss
def validate(model, dataloader, criterion, device):
model.eval()
val_loss = 0.0
with torch.no_grad():
for inputs, targets in tqdm(dataloader, desc="Validation"):
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
val_loss += loss.item()
val_loss /= len(dataloader)
return val_loss
def visualize_results(model, dataloader, device, num_images=10):
model.eval()
inputs, targets = next(iter(dataloader))
inputs, targets = inputs.to(device), targets.to(device)
with torch.no_grad():
outputs = model(inputs)
outputs = outputs.cpu().numpy()
inputs = inputs.cpu().numpy()
targets = targets.cpu().numpy()
plt.figure(figsize=(15, 10))
for i in range(num_images):
plt.subplot(3, num_images, i + 1)
plt.imshow(inputs[i].transpose(1, 2, 0))
plt.title("Original")
plt.axis('off')
plt.subplot(3, num_images, i + 1 + num_images)
plt.imshow(targets[i].transpose(1, 2, 0))
plt.title("Colorblind")
plt.axis('off')
plt.subplot(3, num_images, i + 1 + 2 * num_images)
plt.imshow(outputs[i].transpose(1, 2, 0))
plt.title("Reconstructed")
plt.axis('off')
plt.show()
def plot_and_save_losses(train_losses, val_losses, epoch, path='./loss_plots'):
if not os.path.exists(path):
os.makedirs(path)
epochs = np.arange(1, epoch+2)
plt.figure(figsize=(10, 5))
plt.plot(epochs, train_losses, label='Training Loss')
plt.plot(epochs, val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Losses')
plt.legend()
plt.savefig(f'{path}/loss_epoch_{epoch+1}.png')
plt.close()
def main(args):
# Data transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((256, 256)),
])
# Datasets and Dataloaders
train_dataset = ColorblindDataset(args.dataset_dir, mode='train', transform=transform)
val_dataset = ColorblindDataset(args.dataset_dir, mode='val', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False)
# Model, Loss, Optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device) # 使用简化版的UNet
if args.model_pretrained_path:
model.load_state_dict(torch.load(args.model_pretrained_path))
print("Successfully load past pretrained weights!!")
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
train_losses = []
val_losses = []
# Training and validation loop
for epoch in range(args.num_epochs):
train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
val_loss = validate(model, val_loader, criterion, device)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f'Epoch {epoch + 1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
plot_and_save_losses(train_losses, val_losses, epoch)
visualize_results(model, val_loader, device)
# Save the model
torch.save(model.state_dict(), args.model_save_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="UNet Colorblind Image Reconstruction")
parser.add_argument('--dataset_dir', type=str, default='./dataset', help='Path to the dataset directory')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training and validation')
parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate for the optimizer')
parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs')
parser.add_argument('--model_save_path', type=str, default='./model_weights/color_blind_model.pth', help='Path to save the trained model')
parser.add_argument('--model_pretrained_path', type=str, default='./model_weights/color_blind_model.pth', help='训练好的色盲模拟器模型路径')
args = parser.parse_args()
main(args)
```
在这个版本中,`UNet` 类被简化为只有一层编码器和解码器。其余部分保持不变,包括数据集加载、训练循环和验证函数等。希望这能满足你的需求。如果有任何进一步的问题,请随时告诉我!
阅读全文