代码不完整,继续往下写
时间: 2024-11-25 19:08:36 浏览: 2
您提供的代码已经相对完整,但确实存在一些格式和缩进的问题。以下是经过整理后的完整代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import argparse
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
class UNet(nn.Module):
    """Small convolutional encoder-decoder mapping 3-channel images to 3-channel images.

    The encoder applies two conv + ReLU stages, each followed by a 2x2
    max-pool, so the spatial size is reduced by a factor of 4. The decoder
    undoes this with two stride-2 transposed convolutions and ends with a
    sigmoid, so every output value lies in [0, 1].

    Note: despite the name there are no skip connections between the encoder
    and the decoder. Inputs whose height/width are not multiples of 4 come
    back smaller than they went in, because max-pooling floors odd sizes.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode ``x`` (N, 3, H, W) and return the reconstruction."""
        return self.decoder(self.encoder(x))
class AttentionBlock(nn.Module):
    """Additive attention gate that re-weights a skip feature map.

    Both inputs are projected to ``F_int`` channels with 1x1 convolutions,
    summed, passed through ReLU, and reduced to a single-channel sigmoid map.
    That map (values in [0, 1]) multiplies ``x`` element-wise.

    Args:
        F_g: Number of channels of the gating signal ``g``.
        F_l: Number of channels of the skip feature map ``x``.
        F_int: Number of channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention map computed from ``g`` and ``x``.

        ``g`` and ``x`` are added after projection, so their spatial sizes
        must be compatible for element-wise addition.
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        # The sum is a fresh tensor, so the in-place ReLU cannot clobber g1/x1.
        attention = self.psi(self.relu(g1 + x1))
        # The single-channel map broadcasts over every channel of x.
        return x * attention
class AttentionUNet(nn.Module):
    """U-Net with attention-gated skip connections (3 channels in, 3 channels out).

    Four encoder stages (64, 128, 256, 512 channels) and a 1024-channel
    bottleneck are separated by 2x2 max-pools. Each decoder stage upsamples
    with a stride-2 transposed convolution, gates the matching encoder
    features with an ``AttentionBlock``, concatenates both along the channel
    axis and refines them with a double-conv block. A 1x1 convolution plus
    sigmoid produces the output, so values lie in [0, 1].

    Input height and width must be multiples of 16: after four poolings any
    other size yields upsampled maps that no longer match the encoder
    features they are combined with.
    """

    def __init__(self):
        super().__init__()
        # Contracting path.
        self.encoder1 = self.conv_block(3, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.encoder4 = self.conv_block(256, 512)
        self.bottleneck = self.conv_block(512, 1024)

        # Expanding path. Each decoder block takes twice the channels of its
        # upconv output because gated skip features are concatenated first.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.att4 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.decoder4 = self.conv_block(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.decoder3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.decoder2 = self.conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.decoder1 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> batch norm -> ReLU) layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Map ``x`` (N, 3, H, W) to an (N, 3, H, W) tensor with values in [0, 1]."""
        # Encoder: keep every stage's output for the skip connections.
        e1 = self.encoder1(x)
        e2 = self.encoder2(F.max_pool2d(e1, 2))
        e3 = self.encoder3(F.max_pool2d(e2, 2))
        e4 = self.encoder4(F.max_pool2d(e3, 2))
        b = self.bottleneck(F.max_pool2d(e4, 2))

        # Decoder: upsample, gate the skip features, concatenate, refine.
        d4 = self.upconv4(b)
        s4 = self.att4(g=d4, x=e4)
        d4 = self.decoder4(torch.cat((s4, d4), dim=1))

        d3 = self.upconv3(d4)
        s3 = self.att3(g=d3, x=e3)
        d3 = self.decoder3(torch.cat((s3, d3), dim=1))

        d2 = self.upconv2(d3)
        s2 = self.att2(g=d2, x=e2)
        d2 = self.decoder2(torch.cat((s2, d2), dim=1))

        d1 = self.upconv1(d2)
        s1 = self.att1(g=d1, x=e1)
        d1 = self.decoder1(torch.cat((s1, d1), dim=1))

        return self.sigmoid(self.final_conv(d1))
class ColorblindDataset(Dataset):
    """Dataset of (recolored image, original image) pairs read from disk.

    Files are collected from ``<image_dir>/<mode>/origin_image/*``,
    ``<image_dir>/<mode>/recolor_image/*Protanopia*`` and
    ``<image_dir>/<mode>/correct_image/*``. Each list is sorted and the i-th
    recolored file is paired with the i-th original file, so pairing relies
    on both folders sorting into the same order.

    Args:
        image_dir: Root directory of the dataset.
        mode: Sub-directory of ``image_dir`` to read (default ``'train'``).
        transform: Optional callable applied to both images of a pair.

    Raises:
        IndexError: If there are fewer recolored images than original images.
    """

    def __init__(self, image_dir, mode='train', transform=None):
        self.image_dir = image_dir
        self.mode = mode
        self.transform = transform

        self.normal_images = sorted(
            glob.glob(os.path.join(image_dir, mode, 'origin_image', '*'))
        )
        self.recolor_images = sorted(
            glob.glob(os.path.join(image_dir, mode, 'recolor_image', '*Protanopia*'))
        )
        # Collected and kept as an attribute, but not used for pairing here.
        self.correct_images = sorted(
            glob.glob(os.path.join(image_dir, mode, 'correct_image', '*'))
        )

        # Fail early with a readable message instead of a bare IndexError
        # raised from inside the pairing loop.
        if len(self.recolor_images) < len(self.normal_images):
            raise IndexError(
                f"found {len(self.recolor_images)} recolored images but "
                f"{len(self.normal_images)} original images under "
                f"{os.path.join(image_dir, mode)}"
            )

        # One [recolor_path, normal_path] entry per original image.
        self.image_pair = [
            [self.recolor_images[index], normal_path]
            for index, normal_path in enumerate(self.normal_images)
        ]

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.image_pair)

    def __getitem__(self, idx):
        """Load pair ``idx`` as RGB and return ``(recolor_image, normal_image)``."""
        recolor_path, normal_path = self.image_pair[idx]
        recolor_image = Image.open(recolor_path).convert('RGB')
        normal_image = Image.open(normal_path).convert('RGB')
        if self.transform:
            recolor_image = self.transform(recolor_image)
            normal_image = self.transform(normal_image)
        return recolor_image, normal_image
class Generator(nn.Module):
    """Attention-gated U-Net generator (3 channels in, 3 channels out).

    The architecture is the same as ``AttentionUNet`` in this file: four
    encoder stages (64-512 channels), a 1024-channel bottleneck, and four
    decoder stages that each upsample, gate the matching encoder features
    with an ``AttentionBlock``, concatenate and refine. The output passes
    through a sigmoid, so values lie in [0, 1].

    Input height and width must be multiples of 16 so that the upsampled
    decoder maps line up with the encoder features they are combined with.
    """

    def __init__(self):
        super().__init__()
        # Contracting path.
        self.encoder1 = self.conv_block(3, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.encoder4 = self.conv_block(256, 512)
        self.bottleneck = self.conv_block(512, 1024)

        # Expanding path; decoder inputs are doubled by the skip concatenation.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.att4 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.decoder4 = self.conv_block(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.decoder3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.decoder2 = self.conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.decoder1 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> batch norm -> ReLU) layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Map ``x`` (N, 3, H, W) to an (N, 3, H, W) tensor with values in [0, 1]."""
        # Encoder: keep every stage's output for the skip connections.
        e1 = self.encoder1(x)
        e2 = self.encoder2(F.max_pool2d(e1, 2))
        e3 = self.encoder3(F.max_pool2d(e2, 2))
        e4 = self.encoder4(F.max_pool2d(e3, 2))
        b = self.bottleneck(F.max_pool2d(e4, 2))

        # Decoder: upsample, gate the skip features, concatenate, refine.
        d4 = self.upconv4(b)
        s4 = self.att4(g=d4, x=e4)
        d4 = self.decoder4(torch.cat((s4, d4), dim=1))

        d3 = self.upconv3(d4)
        s3 = self.att3(g=d3, x=e3)
        d3 = self.decoder3(torch.cat((s3, d3), dim=1))

        d2 = self.upconv2(d3)
        s2 = self.att2(g=d2, x=e2)
        d2 = self.decoder2(torch.cat((s2, d2), dim=1))

        d1 = self.upconv1(d2)
        s1 = self.att1(g=d1, x=e1)
        d1 = self.decoder1(torch.cat((s1, d1), dim=1))

        return self.sigmoid(self.final_conv(d1))
class Discriminator(nn.Module):
    """Convolutional critic that scores 3-channel images with one unbounded value.

    Four stride-2 4x4 convolutions shrink the input by a factor of 16, then
    a 16x16 convolution collapses the remaining map. With a 256x256 input
    this leaves exactly one value per sample; larger inputs yield several
    values per sample (all flattened together) and smaller inputs fail in
    the last layer. There is no final sigmoid, so the output is a raw score.

    NOTE(review): the layers use BatchNorm2d while the training code in this
    file applies a per-sample gradient penalty; batch statistics couple the
    samples of a batch, so confirm this combination is intended.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # Collapses the 16x16 map produced by a 256x256 input to 1x1.
            nn.Conv2d(512, 1, kernel_size=16),
        )

    def forward(self, x):
        """Return the critic scores for ``x`` flattened to a 1-D tensor."""
        scores = self.main(x)
        return scores.view(-1)
def compute_iou(outputs, targets, threshold=0.5):
    """Return the mean intersection-over-union of two thresholded 4-D batches.

    Both tensors are binarised with ``value > threshold``; IoU is computed
    per sample over dimensions 1-3 and averaged over the batch.

    Args:
        outputs: Predicted batch, 4-D tensor.
        targets: Reference batch with the same shape as ``outputs``.
        threshold: Cut-off used to binarise both tensors.

    Returns:
        The batch-mean IoU as a Python float.
    """
    outputs = (outputs > threshold).float()
    targets = (targets > threshold).float()
    reduce_dims = (1, 2, 3)
    intersection = (outputs * targets).sum(dim=reduce_dims)
    union = outputs.sum(dim=reduce_dims) + targets.sum(dim=reduce_dims) - intersection
    # The epsilon avoids 0/0 when both masks are empty (that case scores 1.0).
    iou = (intersection + 1e-6) / (union + 1e-6)
    return iou.mean().item()
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric
def compute_psnr(outputs, targets):
    """Return the mean per-sample PSNR between two batches of images.

    Both tensors are moved to NumPy and compared sample by sample with
    scikit-image's ``peak_signal_noise_ratio`` using ``data_range=1.0``,
    i.e. pixel values are expected to span [0, 1].

    Args:
        outputs: Predicted batch tensor.
        targets: Reference batch tensor with the same shape as ``outputs``.

    Returns:
        The PSNR averaged over the first (batch) dimension.

    Raises:
        ZeroDivisionError: If the batch is empty.
    """
    # Detach first so the CPU copy does not carry autograd history.
    outputs = outputs.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy()
    batch_size = outputs.shape[0]
    total = sum(
        psnr_metric(targets[i], outputs[i], data_range=1.0)
        for i in range(batch_size)
    )
    return total / batch_size
def compute_ssim(outputs, targets):
    """Return the mean per-sample SSIM between two batches of images.

    Each sample is transposed from channel-first to channel-last and scored
    with scikit-image's ``structural_similarity`` using ``data_range=1.0``
    and the last axis as the channel axis.

    Args:
        outputs: Predicted batch tensor laid out as (N, C, H, W).
        targets: Reference batch tensor with the same shape as ``outputs``.

    Returns:
        The SSIM averaged over the batch dimension.

    Raises:
        ZeroDivisionError: If the batch is empty.
    """
    # Detach first so the CPU copy does not carry autograd history.
    outputs = outputs.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy()
    batch_size = outputs.shape[0]
    total = 0
    for i in range(batch_size):
        output_img = outputs[i].transpose(1, 2, 0)
        target_img = targets[i].transpose(1, 2, 0)
        height, width, _ = output_img.shape
        min_dim = min(height, width)
        # Use the largest odd window that fits the image, capped at 7 and
        # never below 3.
        largest_odd = min_dim if min_dim % 2 == 1 else min_dim - 1
        win_size = max(min(7, largest_odd), 3)
        total += ssim_metric(
            target_img,
            output_img,
            data_range=1.0,
            channel_axis=-1,
            win_size=win_size,
        )
    return total / batch_size
def wasserstein_loss(pred, target):
    """Return the mean of the element-wise product ``pred * target``.

    Args:
        pred: Tensor of predicted scores.
        target: Tensor multiplied element-wise with ``pred`` (for example
            +1/-1 labels); it must be broadcast-compatible with ``pred``.

    Returns:
        A scalar tensor.
    """
    return torch.mean(pred * target)
from torch.autograd import grad
def compute_gradient_penalty(discriminator, real_samples, fake_samples, device):
    """Return the gradient penalty of ``discriminator`` on random interpolates.

    For every sample a random point on the line between the real and the
    fake sample is drawn; the penalty is the batch mean of
    ``(||grad_x D(x)||_2 - 1) ** 2`` evaluated at those points.

    Args:
        discriminator: Callable mapping a batch tensor to critic scores.
        real_samples: Batch of real samples, 4-D tensor.
        fake_samples: Batch of generated samples with the same shape.
        device: Device on which the interpolation weights are created.

    Returns:
        A scalar tensor that remains attached to the autograd graph, so it
        can be added to the critic loss and back-propagated.
    """
    batch_size = real_samples.size(0)
    # One interpolation weight per sample, broadcast over C, H and W.
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = discriminator(interpolates)
    # A ones vector makes autograd return d(sum of scores)/d(interpolates);
    # sizing it from the critic output keeps shape, dtype and device in sync.
    grad_outputs = torch.ones_like(d_interpolates)
    # create_graph=True keeps the penalty differentiable. The deprecated
    # no-op ``only_inputs`` argument is no longer passed.
    gradients = grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
def train_correction_model(generator, discriminator, dataloader, optimizer_G, optimizer_D, device, lambda_gp, lambda_pixel, n_critic):
generator.train()
discriminator.train()
running_g_loss = 0.0
running_d_loss = 0.0
running_iou = 0.0
running_psnr = 0.0
running_ssim = 0.0
for batch_idx, (inputs, targets) in enumerate(tqdm(dataloader, desc="Training")):
inputs = inputs.to(device)
targets = targets.to(device)
# Train the discriminator
optimizer_D.zero_grad()
corrected_images = generator(inputs)
real_validity = discriminator(targets)
fake_validity = discriminator(corrected_images.detach())
gp = compute_gradient_penalty(discriminator, targets.data, corrected_images.data, device)
d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp
d_loss.backward()
optimizer_D.step()
# Train the generator
if batch_idx % n_critic == 0:
optimizer_G.zero_grad()
corrected_images = generator(inputs)
fake_validity = discriminator(corrected_images)
g_adv_loss = -torch.mean(fake_validity)
pixelwise_loss = nn.L1Loss()
g_pixel_loss = pixelwise_loss(corrected_images, targets)
g_loss = g_adv_loss + lambda_pixel * g_pixel_loss
g_loss.backward()
optimizer_G.step()
else:
g_loss = torch.tensor(0.0)
running_g_loss += g_loss.item()
running_d_loss += d_loss.item()
iou = compute_iou(corrected_images, targets)
psnr = compute_psnr(corrected_images, targets)
ssim = compute_ssim(corrected_images, targets)
running_iou += iou
running_psnr +=
阅读全文