请给我一份完整的、去掉了注意力机制的 `correction_gan` 代码。
时间: 2024-11-22 18:38:11 浏览: 18
gic.zip_Intensity_gamma_gamma correction_gic_intensity correction
以下是去掉了注意力机制的 `correction_gan` 代码（生成器为不带注意力门的普通 U-Net）：
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import argparse
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
class UNet(nn.Module):
    """Minimal encoder/decoder network without skip connections.

    Two (3x3 conv -> ReLU -> 2x2 max-pool) stages shrink each spatial
    dimension by 4x (3 -> 64 -> 128 channels). Two stride-2 transposed
    convolutions restore the resolution (128 -> 64 -> 3 channels), and a
    sigmoid squashes the output into (0, 1). The output matches the input's
    height/width only when both are multiples of 4.
    """

    def __init__(self):
        super().__init__()
        down_layers = []
        for c_in, c_out in ((3, 64), (64, 128)):
            down_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        self.encoder = nn.Sequential(*down_layers)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` and decode it back to a 3-channel image in (0, 1)."""
        latent = self.encoder(x)
        return self.decoder(latent)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.encoder1 = self.conv_block(3, 64)
self.encoder2 = self.conv_block(64, 128)
self.encoder3 = self.conv_block(128, 256)
self.encoder4 = self.conv_block(256, 512)
self.bottleneck = self.conv_block(512, 1024)
self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
self.decoder4 = self.conv_block(1024, 512)
self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.decoder3 = self.conv_block(512, 256)
self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.decoder2 = self.conv_block(256, 128)
self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.decoder1 = self.conv_block(128, 64)
self.final_conv = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)
self.sigmoid = nn.Sigmoid()
def conv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
e1 = self.encoder1(x)
e2 = self.encoder2(F.max_pool2d(e1, 2))
e3 = self.encoder3(F.max_pool2d(e2, 2))
e4 = self.encoder4(F.max_pool2d(e3, 2))
b = self.bottleneck(F.max_pool2d(e4, 2))
d4 = self.upconv4(b)
d4 = torch.cat((e4, d4), dim=1)
d4 = self.decoder4(d4)
d3 = self.upconv3(d4)
d3 = torch.cat((e3, d3), dim=1)
d3 = self.decoder3(d3)
d2 = self.upconv2(d3)
d2 = torch.cat((e2, d2), dim=1)
d2 = self.decoder2(d2)
d1 = self.upconv1(d2)
d1 = torch.cat((e1, d1), dim=1)
d1 = self.decoder1(d1)
out = self.final_conv(d1)
out = self.sigmoid(out)
return out
class Discriminator(nn.Module):
    """Convolutional critic returning one unbounded score per image.

    Four 4x4 stride-2 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels,
    BatchNorm on all but the first, LeakyReLU(0.2) after each) shrink H and W
    by 16x. A final 16x16 convolution collapses the 512-channel map to a
    single channel, and the result is flattened with ``view(-1)``. For
    256x256 inputs that gives exactly one score per sample, shape (N,).

    NOTE(review): the 16x16 final kernel hard-codes 256x256 inputs. Smaller
    images make the last conv fail; larger ones yield several scores per
    image that ``view(-1)`` mixes across the batch. Also, WGAN-GP training
    (see compute_gradient_penalty) is usually paired with a critic that has
    no BatchNorm, since the penalty is per-sample. Changing either would
    alter checkpoints.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        stack = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        width = 64
        for _ in range(3):
            stack.extend([
                nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            width *= 2
        stack.append(nn.Conv2d(width, 1, kernel_size=16))
        self.main = nn.Sequential(*stack)

    def forward(self, x):
        """Score ``x`` and flatten the critic map to a 1-D tensor."""
        scores = self.main(x)
        return scores.view(-1)
def compute_iou(outputs, targets, threshold=0.5):
    """Return the batch-mean IoU of thresholded ``outputs`` vs ``targets``.

    Both tensors are binarised with ``> threshold``. Intersection and union
    are summed over dims 1-3 (so 4-D inputs are expected), and each sample's
    IoU is smoothed with 1e-6 in numerator and denominator. Two empty masks
    therefore score 1.0. Returns a Python float.
    """
    eps = 1e-6
    dims = (1, 2, 3)
    pred_mask = outputs.gt(threshold).float()
    true_mask = targets.gt(threshold).float()
    overlap = (pred_mask * true_mask).sum(dim=dims)
    combined = pred_mask.sum(dim=dims) + true_mask.sum(dim=dims) - overlap
    per_sample = (overlap + eps) / (combined + eps)
    return per_sample.mean().item()
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric
def compute_psnr(outputs, targets):
    """Return the mean PSNR of ``outputs`` against ``targets`` over the batch.

    Both tensors are detached and copied to NumPy. Each sample is scored with
    ``psnr_metric(target, output, data_range=1.0)``, i.e. pixel values are
    taken to span [0, 1].

    NOTE(review): an exact match has zero MSE, which presumably gives an
    infinite PSNR and an infinite average — confirm if that matters.
    """
    predicted = outputs.detach().cpu().numpy()
    reference = targets.detach().cpu().numpy()
    batch = predicted.shape[0]
    scores = (psnr_metric(reference[idx], predicted[idx], data_range=1.0) for idx in range(batch))
    return sum(scores) / batch
def compute_ssim(outputs, targets):
    """Return the mean SSIM of ``outputs`` against ``targets`` over the batch.

    Each (C, H, W) sample is moved to channels-last and scored with
    ``ssim_metric`` (``data_range=1.0``, ``channel_axis=-1``). The window size
    is the largest odd value <= min(H, W), capped at 7 and never below 3, so
    images smaller than 7 pixels still get a valid odd window.
    """
    predicted = outputs.detach().cpu().numpy()
    reference = targets.detach().cpu().numpy()
    batch = predicted.shape[0]
    total = 0
    for idx in range(batch):
        pred_hwc = predicted[idx].transpose(1, 2, 0)
        ref_hwc = reference[idx].transpose(1, 2, 0)
        height, width, _ = pred_hwc.shape
        smallest = min(height, width)
        largest_odd = smallest if smallest % 2 == 1 else smallest - 1
        window = max(min(7, largest_odd), 3)
        total += ssim_metric(ref_hwc, pred_hwc, data_range=1.0, channel_axis=-1, win_size=window)
    return total / batch
def wasserstein_loss(pred, target):
    """Return the mean of the element-wise product ``pred * target``.

    With ``target`` holding +1/-1 labels this is the classic Wasserstein
    critic loss. train_correction_model computes its WGAN losses inline and
    does not call this helper.
    """
    weighted = pred * target
    return weighted.mean()
from torch.autograd import grad
def compute_gradient_penalty(discriminator, real_samples, fake_samples, device):
    """Return the WGAN-GP gradient penalty ``mean((||grad D(x_hat)||_2 - 1)^2)``.

    ``x_hat`` is a per-sample random convex combination of a real and a fake
    sample. The gradient is taken with ``create_graph=True``, so the returned
    scalar can be backpropagated into the discriminator's parameters.

    Args:
        discriminator: callable mapping a batch to scores. Any output shape
            is accepted, e.g. ``(N,)`` or ``(N, 1)``.
        real_samples: batch of shape ``(N, ...)``.
        fake_samples: batch with the same shape as ``real_samples``.
            Both are detached here, so the penalty never reaches the generator.
        device: device on which the interpolation weights are drawn.

    Returns:
        A scalar tensor.
    """
    real_samples = real_samples.detach()
    fake_samples = fake_samples.detach()
    batch_size = real_samples.size(0)
    # One mixing weight per sample, broadcast over all remaining dims
    # (equals (N, 1, 1, 1) for image batches).
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = discriminator(interpolates)
    # grad_outputs must match the critic output's shape exactly; ones_like
    # works for (N,) and (N, 1) critics alike.
    gradients = grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
def train_correction_model(generator, discriminator, dataloader, optimizer_G, optimizer_D, device, lambda_gp, lambda_pixel, n_critic):
    """Run one WGAN-GP training epoch and return averaged losses and metrics.

    The critic (discriminator) is updated on every batch. Its loss is
    ``-mean(D(real)) + mean(D(fake)) + lambda_gp * gradient_penalty``. The
    generator is updated only on batches whose index is a multiple of
    ``n_critic``. Its loss is ``-mean(D(fake)) + lambda_pixel * L1(fake, target)``.

    Args:
        generator: model mapping ``inputs`` to corrected images.
        discriminator: critic returning one score per sample.
        dataloader: iterable of ``(inputs, targets)`` batches.
        optimizer_G: optimizer for the generator.
        optimizer_D: optimizer for the critic.
        device: device the batches are moved to.
        lambda_gp: weight of the gradient penalty in the critic loss.
        lambda_pixel: weight of the L1 term in the generator loss.
        n_critic: generator update period, in batches (must be >= 1).

    Returns:
        ``(g_loss, d_loss, iou, psnr, ssim)`` for the epoch. ``g_loss`` is
        averaged over the batches where the generator was actually updated.
        All other values are averaged over all batches. All are 0.0 for an
        empty dataloader.
    """
    generator.train()
    discriminator.train()
    # L1Loss is stateless, so build it once instead of once per batch.
    pixelwise_loss = nn.L1Loss()
    running_g_loss = 0.0
    running_d_loss = 0.0
    running_iou = 0.0
    running_psnr = 0.0
    running_ssim = 0.0
    num_batches = 0
    num_g_steps = 0
    for batch_idx, (inputs, targets) in enumerate(tqdm(dataloader, desc="Training")):
        inputs = inputs.to(device)
        targets = targets.to(device)
        num_batches += 1

        # ---- Critic step (every batch) ----
        optimizer_D.zero_grad()
        # The critic loss never backpropagates into the generator, so this
        # forward pass needs no autograd graph (same values, less memory).
        with torch.no_grad():
            corrected_images = generator(inputs)
        real_validity = discriminator(targets)
        fake_validity = discriminator(corrected_images)
        gp = compute_gradient_penalty(discriminator, targets.detach(), corrected_images, device)
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step (once every n_critic batches) ----
        if batch_idx % n_critic == 0:
            optimizer_G.zero_grad()
            corrected_images = generator(inputs)
            fake_validity = discriminator(corrected_images)
            g_adv_loss = -torch.mean(fake_validity)
            g_pixel_loss = pixelwise_loss(corrected_images, targets)
            g_loss = g_adv_loss + lambda_pixel * g_pixel_loss
            g_loss.backward()
            optimizer_G.step()
            running_g_loss += g_loss.item()
            num_g_steps += 1

        running_d_loss += d_loss.item()
        # Metrics use this batch's generator output, computed before any
        # generator update in this iteration.
        running_iou += compute_iou(corrected_images, targets)
        running_psnr += compute_psnr(corrected_images, targets)
        running_ssim += compute_ssim(corrected_images, targets)

    if num_batches == 0:
        return 0.0, 0.0, 0.0, 0.0, 0.0
    # Dividing the generator loss by all batches would shrink it by roughly
    # n_critic and make it incomparable with the validation generator loss.
    epoch_g_loss = running_g_loss / max(num_g_steps, 1)
    epoch_d_loss = running_d_loss / num_batches
    epoch_iou = running_iou / num_batches
    epoch_psnr = running_psnr / num_batches
    epoch_ssim = running_ssim / num_batches
    return epoch_g_loss, epoch_d_loss, epoch_iou, epoch_psnr, epoch_ssim
def validate_correction_model(generator, discriminator, dataloader, device, lambda_gp, lambda_pixel=0.0):
    """Evaluate generator and critic on ``dataloader`` without updating them.

    Both models are put in eval mode and run under ``torch.no_grad()``.

    Args:
        generator: model mapping ``inputs`` to corrected images.
        discriminator: critic returning one score per sample.
        dataloader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to.
        lambda_gp: unused. The gradient penalty needs input gradients, which
            are unavailable under ``no_grad``; the parameter is presumably
            kept for symmetry with train_correction_model.
        lambda_pixel: weight of an L1 term added to the reported generator
            loss. The default 0.0 reports the adversarial term only (the
            previous behaviour). Pass the training value to make validation
            and training generator losses directly comparable.

    Returns:
        ``(g_loss, d_loss, iou, psnr, ssim)`` averaged over batches, or all
        0.0 for an empty dataloader. ``d_loss`` omits the gradient penalty.
    """
    generator.eval()
    discriminator.eval()
    running_g_loss = 0.0
    running_d_loss = 0.0
    running_iou = 0.0
    running_psnr = 0.0
    running_ssim = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation"):
            inputs = inputs.to(device)
            targets = targets.to(device)
            num_batches += 1
            corrected_images = generator(inputs)
            real_validity = discriminator(targets)
            fake_validity = discriminator(corrected_images)
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
            g_loss = -torch.mean(fake_validity)
            if lambda_pixel:
                g_loss = g_loss + lambda_pixel * F.l1_loss(corrected_images, targets)
            running_g_loss += g_loss.item()
            running_d_loss += d_loss.item()
            running_iou += compute_iou(corrected_images, targets)
            running_psnr += compute_psnr(corrected_images, targets)
            running_ssim += compute_ssim(corrected_images, targets)
    if num_batches == 0:
        return 0.0, 0.0, 0.0, 0.0, 0.0
    epoch_g_loss = running_g_loss / num_batches
    epoch_d_loss = running_d_loss / num_batches
    epoch_iou = running_iou / num_batches
    epoch_psnr = running_psnr / num_batches
    epoch_ssim = running_ssim / num_batches
    return epoch_g_loss, epoch_d_loss, epoch_iou, epoch_psnr, epoch_ssim
def visualize_results(generator, dataloader, device, num_images=10, save_path='./results'):
    """Plot targets, inputs and generator outputs for the first batch.

    Draws a 3-row grid: targets ("Original"), inputs ("Simulated
    Colorblind") and generator outputs ("Corrected"). It saves the grid to
    ``f'{save_path}_visualization.png'``, shows it, then closes the figure so
    repeated calls do not accumulate open figures.

    Args:
        generator: model to run; it is left in eval mode.
        dataloader: source of ``(inputs, targets)``; only its first batch is used.
        device: device to run the generator on.
        num_images: maximum number of columns; clipped to the batch size.
        save_path: prefix of the output file. A directory with this exact
            name is also created.
    """
    generator.eval()
    inputs, targets = next(iter(dataloader))
    with torch.no_grad():
        corrected_images = generator(inputs.to(device))
    # Channels-first tensors -> NumPy arrays for matplotlib.
    inputs = inputs.cpu().numpy()
    targets = targets.cpu().numpy()
    corrected_images = corrected_images.cpu().numpy()
    # A batch smaller than num_images previously raised IndexError.
    num_images = min(num_images, corrected_images.shape[0])
    os.makedirs(save_path, exist_ok=True)
    rows = (
        (targets, "Original"),
        (inputs, "Simulated Colorblind"),
        (corrected_images, "Corrected"),
    )
    fig = plt.figure(figsize=(20, 10))
    for row, (images, title) in enumerate(rows):
        for i in range(num_images):
            plt.subplot(3, num_images, row * num_images + i + 1)
            plt.imshow(images[i].transpose(1, 2, 0))
            plt.title(title)
            plt.axis('off')
    plt.tight_layout()
    # NOTE(review): the figure is written beside the directory created above,
    # not inside it; confirm which location was intended.
    plt.savefig(f'{save_path}_visualization.png')
    plt.show()
    plt.close(fig)
def plot_and_save_metrics(train_metrics, val_metrics, epoch, path='./metrics_plots'):
if not os.path.exists(path):
os.makedirs(path)
epochs = np.arange(1, epoch + 1)
train_g_losses, train_d_losses, train_ious, train_psnrs, train_ssims = zip(*train_metrics)
val_g_losses, val_d_losses, val_ious, val_psnrs, val_ssims = zip(*val_metrics)
plt.figure()
plt.plot(epochs, train_g_losses, label='Training Generator Loss')
plt.plot(epochs, val_g_losses, label='Validation Generator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Generator Loss over Epochs')
plt.legend()
plt.savefig(f'{path}/generator_loss.png')
plt.close()
plt.figure()
plt.plot(epochs, train_d_losses, label='Training Discriminator Loss')
plt.plot(epochs, val_d_losses, label='Validation Discriminator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Discriminator Loss over Epochs')
plt.legend()
plt.savefig(f'{path}/discriminator_loss.png')
plt.close()
plt.figure()
plt.plot(epochs, train_ious, label='Training IoU')
plt.plot(epochs, val_ious, label='Validation IoU')
plt.xlabel('Epoch')
plt.ylabel('IoU')
plt.title('IoU over Epochs')
plt.legend()
plt.savefig(f'{path}/iou.png')
plt.close()
plt.figure()
plt.plot(epochs, train_psnrs, label='Training PSNR')
plt.plot(epochs, val_psnrs, label='Validation PSNR')
plt.xlabel('Epoch')
plt.ylabel('PSNR')
plt.title('PSNR over Epochs')
plt.legend()
plt.savefig(f'{path}/psnr.png')
plt.
```
（以上代码在此处被截断，剩余部分请点击“阅读全文”查看。）