三层 UNet，不加注意力机制，请把完整代码全部复制粘贴过来
时间: 2024-11-25 14:16:15 浏览: 8
以下是不带注意力机制的三层 UNet 模型（作为 WGAN-GP 生成器）及其训练代码（判别器 `Discriminator` 类未包含在内）：
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import argparse
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
class UNet(nn.Module):
    """Three-level U-Net without attention, mapping a 3-channel image to 3 channels.

    Encoder: three conv blocks (64, 128, 256 channels) separated by 2x2
    max-pooling, followed by a 512-channel bottleneck. Each decoder stage
    upsamples with a stride-2 transposed convolution, concatenates the
    matching encoder feature map (skip connection) and applies a conv block.
    A 1x1 convolution followed by a sigmoid produces the output, so every
    output value lies in [0, 1].

    Inputs of any spatial size are accepted: when pooling floors an odd
    dimension, the upsampled map is zero-padded to the skip map's size before
    concatenation, and the output has the same H x W as the input.
    """

    def __init__(self):
        super().__init__()
        self.encoder1 = self.conv_block(3, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.bottleneck = self.conv_block(256, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        # Decoder input channels = upsampled channels + skip channels.
        self.decoder3 = self.conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(128, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> BatchNorm -> ReLU) stages; H and W are preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` so its last two dims equal those of ``ref``.

        After floor-pooling and a stride-2 transposed conv, ``x`` is at most
        one pixel smaller than ``ref`` per spatial dim. The padding is split
        as evenly as possible, with any extra pixel going to the right/bottom.
        """
        diff_h = ref.size(2) - x.size(2)
        diff_w = ref.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        # F.pad pads the last dim first: (left, right, top, bottom).
        return F.pad(
            x,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Run the U-Net; returns a tensor with 3 channels and the input's H x W."""
        e1 = self.encoder1(x)
        e2 = self.encoder2(F.max_pool2d(e1, 2))
        e3 = self.encoder3(F.max_pool2d(e2, 2))
        b = self.bottleneck(F.max_pool2d(e3, 2))

        d3 = self._match_size(self.upconv3(b), e3)
        d3 = self.decoder3(torch.cat((e3, d3), dim=1))

        d2 = self._match_size(self.upconv2(d3), e2)
        d2 = self.decoder2(torch.cat((e2, d2), dim=1))

        d1 = self._match_size(self.upconv1(d2), e1)
        d1 = self.decoder1(torch.cat((e1, d1), dim=1))

        return self.sigmoid(self.final_conv(d1))
class ColorblindDataset(Dataset):
    """Paired dataset of (Protanopia-recolored image, original image).

    Directory layout read under ``<image_dir>/<mode>/``:
      origin_image/   original images (returned second, the target)
      recolor_image/  recolored images; only names containing "Protanopia"
                      are used (returned first, the input)
      correct_image/  collected into ``self.correct_images`` but not used
                      for pairing

    Each list is sorted and the lists are paired by position, so the two
    folders must contain the same number of matching files.

    Raises:
        ValueError: if the number of Protanopia images differs from the
            number of origin images (positional pairing would be wrong).
    """

    def __init__(self, image_dir, mode='train', transform=None):
        self.image_dir = image_dir
        self.mode = mode
        self.transform = transform
        split_dir = os.path.join(image_dir, mode)
        self.normal_images = sorted(glob.glob(os.path.join(split_dir, 'origin_image', '*')))
        self.recolor_images = sorted(
            glob.glob(os.path.join(split_dir, 'recolor_image', '*Protanopia*'))
        )
        self.correct_images = sorted(glob.glob(os.path.join(split_dir, 'correct_image', '*')))
        if len(self.recolor_images) != len(self.normal_images):
            raise ValueError(
                f"Found {len(self.recolor_images)} '*Protanopia*' recolor images but "
                f"{len(self.normal_images)} origin images under {split_dir!r}; images "
                "are paired by sorted position, so the counts must match."
            )
        # Each entry is [recolor_path, normal_path].
        self.image_pair = [
            [recolor, normal]
            for recolor, normal in zip(self.recolor_images, self.normal_images)
        ]

    def __len__(self):
        return len(self.image_pair)

    def __getitem__(self, idx):
        """Return ``(recolor_image, normal_image)``, transformed if a transform is set."""
        recolor_path, normal_path = self.image_pair[idx]
        # Open inside ``with`` so file handles are released deterministically;
        # convert() returns a new, fully loaded image that outlives the handle.
        with Image.open(recolor_path) as img:
            recolor_image = img.convert('RGB')
        with Image.open(normal_path) as img:
            normal_image = img.convert('RGB')
        if self.transform:
            recolor_image = self.transform(recolor_image)
            normal_image = self.transform(normal_image)
        return recolor_image, normal_image
def compute_iou(outputs, targets, threshold=0.5):
    """Return the batch-mean IoU of the binarized ``outputs`` and ``targets``.

    Both tensors are thresholded with ``> threshold``. Intersection and union
    are summed over dims 1, 2 and 3 per sample, a 1e-6 smoothing term is added
    to both, and the per-sample ratios are averaged into a Python float.
    """
    pred_mask = outputs.gt(threshold).float()
    true_mask = targets.gt(threshold).float()
    per_sample_dims = (1, 2, 3)
    overlap = (pred_mask * true_mask).sum(dim=per_sample_dims)
    combined = (
        pred_mask.sum(dim=per_sample_dims)
        + true_mask.sum(dim=per_sample_dims)
        - overlap
    )
    smooth = 1e-6
    ratios = (overlap + smooth) / (combined + smooth)
    return ratios.mean().item()
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric
def compute_psnr(outputs, targets):
    """Return the mean PSNR over the batch, computed per sample with skimage.

    Each sample is compared as ``psnr_metric(target, output, data_range=1.0)``,
    so images are presumably scaled to [0, 1]. The batch is taken along the
    first axis.
    """
    preds = outputs.cpu().detach().numpy()
    refs = targets.cpu().detach().numpy()
    total = 0
    for pred, ref in zip(preds, refs):
        total += psnr_metric(ref, pred, data_range=1.0)
    return total / preds.shape[0]
def compute_ssim(outputs, targets):
    """Return the mean SSIM over the batch, computed per sample with skimage.

    Each sample is moved from (C, H, W) to (H, W, C) and scored with
    ``channel_axis=-1`` and ``data_range=1.0``. The window size is the largest
    odd number <= min(H, W), capped at 7 and floored at 3.
    """
    preds = outputs.cpu().detach().numpy()
    refs = targets.cpu().detach().numpy()
    total = 0
    for pred_chw, ref_chw in zip(preds, refs):
        pred_hwc = pred_chw.transpose(1, 2, 0)
        ref_hwc = ref_chw.transpose(1, 2, 0)
        height, width, _ = pred_hwc.shape
        shortest = min(height, width)
        # skimage requires an odd window no larger than the image side.
        odd_side = shortest if shortest % 2 == 1 else shortest - 1
        window = max(min(7, odd_side), 3)
        total += ssim_metric(
            ref_hwc, pred_hwc, data_range=1.0, channel_axis=-1, win_size=window
        )
    return total / preds.shape[0]
def wasserstein_loss(pred, target):
    """Return the mean of the element-wise product ``pred * target`` as a tensor."""
    product = pred * target
    return product.mean()
from torch.autograd import grad
def compute_gradient_penalty(discriminator, real_samples, fake_samples, device):
    """Return the WGAN-GP gradient penalty E[(||grad_x D(x_hat)||_2 - 1)^2].

    ``x_hat`` is a random per-sample interpolation between ``real_samples``
    and ``fake_samples``. The graph is kept (``create_graph=True``) so the
    penalty can be back-propagated into the critic's parameters.

    Args:
        discriminator: callable critic; its output may have any shape (e.g. (B,),
            (B, 1) or a patch map) because ``grad_outputs`` mirrors it.
        real_samples, fake_samples: tensors of identical shape with the batch
            on dim 0 (any rank).
        device: device on which the interpolation weights are drawn.

    Returns:
        A scalar tensor.
    """
    batch_size = real_samples.size(0)
    # One weight per sample, broadcast over all remaining dims.
    alpha = torch.rand(
        (batch_size,) + (1,) * (real_samples.dim() - 1), device=device
    )
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = discriminator(interpolates)
    # grad_outputs must match the critic output's shape and dtype.
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # reshape (not view): input gradients are not guaranteed to be contiguous.
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
def train_correction_model(generator, discriminator, dataloader, optimizer_G, optimizer_D, device, lambda_gp, lambda_pixel, n_critic):
    """Run one WGAN-GP training epoch and return averaged losses and metrics.

    The critic (``discriminator``) is updated on every batch with
    ``-E[D(real)] + E[D(fake)] + lambda_gp * GP``. The generator is updated
    only on batches where ``batch_idx % n_critic == 0``, with
    ``-E[D(G(x))] + lambda_pixel * L1(G(x), target)``.

    Returns:
        ``(g_loss, d_loss, iou, psnr, ssim)``. ``g_loss`` is averaged over the
        batches on which the generator was actually updated. All other values
        are averaged over every batch. Every value is 0.0 if the dataloader
        yields no batches.
    """
    generator.train()
    discriminator.train()
    pixelwise_loss = nn.L1Loss()  # stateless, so build it once per epoch
    running_g_loss = 0.0
    running_d_loss = 0.0
    running_iou = 0.0
    running_psnr = 0.0
    running_ssim = 0.0
    g_steps = 0
    num_batches = 0

    for batch_idx, (inputs, targets) in enumerate(tqdm(dataloader, desc="Training")):
        inputs = inputs.to(device)
        targets = targets.to(device)
        num_batches += 1

        # ---- Critic step ----
        optimizer_D.zero_grad()
        # The critic step never back-propagates into the generator, so skip
        # building the generator's autograd graph here.
        with torch.no_grad():
            corrected_images = generator(inputs)
        real_validity = discriminator(targets)
        fake_validity = discriminator(corrected_images)
        gp = compute_gradient_penalty(discriminator, targets, corrected_images, device)
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step, once every n_critic batches ----
        if batch_idx % n_critic == 0:
            optimizer_G.zero_grad()
            corrected_images = generator(inputs)
            fake_validity = discriminator(corrected_images)
            g_adv_loss = -torch.mean(fake_validity)
            g_pixel_loss = pixelwise_loss(corrected_images, targets)
            g_loss = g_adv_loss + lambda_pixel * g_pixel_loss
            g_loss.backward()
            optimizer_G.step()
            running_g_loss += g_loss.item()
            g_steps += 1

        running_d_loss += d_loss.item()
        # Metrics are for reporting only; no autograd graph needed.
        with torch.no_grad():
            running_iou += compute_iou(corrected_images, targets)
            running_psnr += compute_psnr(corrected_images, targets)
            running_ssim += compute_ssim(corrected_images, targets)

    batch_denom = max(num_batches, 1)
    epoch_g_loss = running_g_loss / max(g_steps, 1)
    epoch_d_loss = running_d_loss / batch_denom
    epoch_iou = running_iou / batch_denom
    epoch_psnr = running_psnr / batch_denom
    epoch_ssim = running_ssim / batch_denom
    return epoch_g_loss, epoch_d_loss, epoch_iou, epoch_psnr, epoch_ssim
def validate_correction_model(generator, discriminator, dataloader, device, lambda_gp):
    """Evaluate the generator and critic on ``dataloader`` without updating them.

    Both models are put in eval mode and run under ``torch.no_grad()``.

    Args:
        lambda_gp: accepted for signature parity with training but unused; the
            gradient penalty is not computed during validation.

    Returns:
        ``(g_loss, d_loss, iou, psnr, ssim)`` averaged over batches, or 0.0
        for each if the dataloader yields no batches. ``d_loss`` excludes the
        gradient penalty. ``g_loss`` is the adversarial term only, without the
        L1 pixel term used in training.
    """
    generator.eval()
    discriminator.eval()
    running_g_loss = 0.0
    running_d_loss = 0.0
    running_iou = 0.0
    running_psnr = 0.0
    running_ssim = 0.0
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation"):
            inputs = inputs.to(device)
            targets = targets.to(device)
            corrected_images = generator(inputs)
            real_validity = discriminator(targets)
            fake_validity = discriminator(corrected_images)
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
            g_loss = -torch.mean(fake_validity)
            running_g_loss += g_loss.item()
            running_d_loss += d_loss.item()
            running_iou += compute_iou(corrected_images, targets)
            running_psnr += compute_psnr(corrected_images, targets)
            running_ssim += compute_ssim(corrected_images, targets)
            num_batches += 1

    denom = max(num_batches, 1)
    epoch_g_loss = running_g_loss / denom
    epoch_d_loss = running_d_loss / denom
    epoch_iou = running_iou / denom
    epoch_psnr = running_psnr / denom
    epoch_ssim = running_ssim / denom
    return epoch_g_loss, epoch_d_loss, epoch_iou, epoch_psnr, epoch_ssim
def visualize_results(generator, dataloader, device, num_images=10, save_path='./results'):
    """Plot targets, inputs and generator outputs for the first batch.

    Produces a 3-row grid with rows titled "Original" (targets),
    "Simulated Colorblind" (inputs) and "Corrected" (generator outputs).
    ``num_images`` columns are drawn, clipped to the batch size. The figure is
    saved as ``<save_path>/visualization.png``, shown, and then closed.
    Tensors are assumed to be (B, C, H, W) images displayable by ``imshow``
    after moving channels last.
    """
    generator.eval()
    inputs, targets = next(iter(dataloader))
    inputs = inputs.to(device)
    with torch.no_grad():
        corrected_images = generator(inputs)
    inputs = inputs.cpu().numpy()
    targets = targets.cpu().numpy()
    corrected_images = corrected_images.cpu().numpy()

    # Never index past the end of the batch.
    num_images = min(num_images, inputs.shape[0])
    os.makedirs(save_path, exist_ok=True)

    rows = (
        (targets, "Original"),
        (inputs, "Simulated Colorblind"),
        (corrected_images, "Corrected"),
    )
    plt.figure(figsize=(20, 10))
    for i in range(num_images):
        for row, (images, title) in enumerate(rows):
            plt.subplot(3, num_images, i + 1 + row * num_images)
            plt.imshow(images[i].transpose(1, 2, 0))
            plt.title(title)
            plt.axis('off')
    plt.tight_layout()
    # Save inside the directory created above (previously written beside it).
    plt.savefig(os.path.join(save_path, 'visualization.png'))
    plt.show()
    plt.close()
def plot_and_save_metrics(train_metrics, val_metrics, epoch, path='./metrics_plots'):
    """Save one train-vs-validation line chart per metric into ``path``.

    ``train_metrics`` and ``val_metrics`` are sequences of per-epoch
    ``(g_loss, d_loss, iou, psnr, ssim)`` tuples. The x-axis runs from 1 to
    ``epoch``, so each sequence must have ``epoch`` entries. Five PNG files
    are written, and each figure is closed after saving.
    """
    if not os.path.exists(path):
        os.makedirs(path)
    x_axis = np.arange(1, epoch + 1)
    tr_g, tr_d, tr_iou, tr_psnr, tr_ssim = zip(*train_metrics)
    va_g, va_d, va_iou, va_psnr, va_ssim = zip(*val_metrics)

    # (train series, val series, train label, val label, y label, title, file)
    charts = (
        (tr_g, va_g, 'Training Generator Loss', 'Validation Generator Loss',
         'Loss', 'Generator Loss over Epochs', 'generator_loss.png'),
        (tr_d, va_d, 'Training Discriminator Loss', 'Validation Discriminator Loss',
         'Loss', 'Discriminator Loss over Epochs', 'discriminator_loss.png'),
        (tr_iou, va_iou, 'Training IoU', 'Validation IoU',
         'IoU', 'IoU over Epochs', 'iou.png'),
        (tr_psnr, va_psnr, 'Training PSNR', 'Validation PSNR',
         'PSNR', 'PSNR over Epochs', 'psnr.png'),
        (tr_ssim, va_ssim, 'Training SSIM', 'Validation SSIM',
         'SSIM', 'SSIM over Epochs', 'ssim.png'),
    )
    for train_vals, val_vals, train_label, val_label, y_label, title, filename in charts:
        plt.figure()
        plt.plot(x_axis, train_vals, label=train_label)
        plt.plot(x_axis, val_vals, label=val_label)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.savefig(f'{path}/{filename}')
        plt.close()
def main(args):
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
])
train_dataset = ColorblindDataset(args.dataset_dir, mode='train', transform=transform)
val_dataset = ColorblindDataset(args.dataset_dir, mode='val', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = UNet().to(device)
discriminator = Discriminator().to(device)
if args.generator_model_weight_path:
print(f"Loading generator weights from {args.generator_model_weight_path}")
generator.load_state_dict(torch.load(args.generator_model_weight_path))
```
（原页面在此处截断，`main` 函数的其余部分未显示。）