Explain the following function:
```python
def compute_gradient_penalty(D, real_samples, fake_samples):
    alpha = torch.from_numpy(np.random.random((real_samples.size()[0], 1, 1, 1))).float().cuda()
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates, _ = D(interpolates)
    fake = autograd.Variable(torch.ones(real_samples.size()[0]), requires_grad=False).cuda()
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
```
This function computes the gradient penalty term for a Wasserstein GAN with gradient penalty (WGAN-GP), which improves training stability and the quality of the generated samples. Its inputs are the discriminator (D), a batch of real images (real_samples), and a batch of generated images (fake_samples); its output is the gradient penalty term (gradient_penalty).
It works as follows:
- Draw a random tensor alpha uniformly from [0, 1) with shape (batch_size, 1, 1, 1).
- Use alpha to interpolate between the real and fake samples, producing interpolates with the same shape as the real images.
- Feed interpolates through the discriminator D to obtain the output d_interpolates.
- Build an all-ones tensor (fake) with one entry per sample and pass it as grad_outputs, so that autograd.grad returns the gradient of each per-sample output.
- Compute the gradient of d_interpolates with respect to interpolates using autograd.grad; create_graph=True keeps the graph so the penalty itself can be backpropagated.
- Reshape the gradients to (batch_size, -1) and take the L2 norm of each row.
- Compute the gradient penalty as the mean of the squared deviations of these norms from 1.
- Return gradient_penalty (a short sketch of how this term enters the critic loss follows below).
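For context, here is a minimal sketch of how the returned penalty is typically combined with the critic loss in WGAN-GP training. The names `G`, `D`, `z`, `real_imgs`, `optimizer_D`, and the weight `lambda_gp = 10.0` are illustrative assumptions, not taken from the original code; the `[0]` indexing only reflects that the discriminator in the question returns a tuple.

```python
# Hypothetical critic update step (all names except compute_gradient_penalty are placeholders).
lambda_gp = 10.0                      # penalty weight commonly used with WGAN-GP
optimizer_D.zero_grad()
fake_imgs = G(z).detach()             # generated batch, detached for the critic update
gp = compute_gradient_penalty(D, real_imgs, fake_imgs)
# The critic maximizes D(real) - D(fake); as a loss, minimize the negation plus the weighted penalty.
d_loss = -torch.mean(D(real_imgs)[0]) + torch.mean(D(fake_imgs)[0]) + lambda_gp * gp
d_loss.backward()
optimizer_D.step()
```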
Related question
Give me a complete version of the correction_gan code with the attention mechanism removed.
Here is the `correction_gan` code with the attention mechanism removed:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import argparse
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
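# Simple encoder-decoder baseline; note that it is not referenced by the training code shown below.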
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
nn.ReLU(),
nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2),
nn.Sigmoid(),
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
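# Generator: a plain U-Net (four encoder stages, a bottleneck, four decoder stages with skip connections,
# and a sigmoid output); this is the version with the attention blocks removed.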
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.encoder1 = self.conv_block(3, 64)
self.encoder2 = self.conv_block(64, 128)
self.encoder3 = self.conv_block(128, 256)
self.encoder4 = self.conv_block(256, 512)
self.bottleneck = self.conv_block(512, 1024)
self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
self.decoder4 = self.conv_block(1024, 512)
self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.decoder3 = self.conv_block(512, 256)
self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.decoder2 = self.conv_block(256, 128)
self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.decoder1 = self.conv_block(128, 64)
self.final_conv = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)
self.sigmoid = nn.Sigmoid()
def conv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
e1 = self.encoder1(x)
e2 = self.encoder2(F.max_pool2d(e1, 2))
e3 = self.encoder3(F.max_pool2d(e2, 2))
e4 = self.encoder4(F.max_pool2d(e3, 2))
b = self.bottleneck(F.max_pool2d(e4, 2))
d4 = self.upconv4(b)
d4 = torch.cat((e4, d4), dim=1)
d4 = self.decoder4(d4)
d3 = self.upconv3(d4)
d3 = torch.cat((e3, d3), dim=1)
d3 = self.decoder3(d3)
d2 = self.upconv2(d3)
d2 = torch.cat((e2, d2), dim=1)
d2 = self.decoder2(d2)
d1 = self.upconv1(d2)
d1 = torch.cat((e1, d1), dim=1)
d1 = self.decoder1(d1)
out = self.final_conv(d1)
out = self.sigmoid(out)
return out
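# Discriminator used as a WGAN critic: strided convolutions with LeakyReLU, no sigmoid, and a final
# 16x16 convolution that reduces the feature map to a single score per image (assumes 256x256 inputs: 256 / 2^4 = 16).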
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, kernel_size=16),
)
def forward(self, x):
return self.main(x).view(-1)
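# Batch-averaged image metrics (IoU, PSNR, SSIM). IoU thresholds both tensors at 0.5; PSNR and SSIM
# use scikit-image with data_range=1.0, so outputs and targets are assumed to lie in [0, 1].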
def compute_iou(outputs, targets, threshold=0.5):
outputs = (outputs > threshold).float()
targets = (targets > threshold).float()
intersection = (outputs * targets).sum(dim=(1, 2, 3))
union = outputs.sum(dim=(1, 2, 3)) + targets.sum(dim=(1, 2, 3)) - intersection
iou = (intersection + 1e-6) / (union + 1e-6)
return iou.mean().item()
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric
def compute_psnr(outputs, targets):
outputs = outputs.cpu().detach().numpy()
targets = targets.cpu().detach().numpy()
psnr = 0
for i in range(outputs.shape[0]):
psnr += psnr_metric(targets[i], outputs[i], data_range=1.0)
return psnr / outputs.shape[0]
def compute_ssim(outputs, targets):
outputs = outputs.cpu().detach().numpy()
targets = targets.cpu().detach().numpy()
ssim = 0
for i in range(outputs.shape[0]):
output_img = outputs[i].transpose(1, 2, 0)
target_img = targets[i].transpose(1, 2, 0)
H, W, _ = output_img.shape
min_dim = min(H, W)
win_size = min(7, min_dim if min_dim % 2 == 1 else min_dim - 1)
win_size = max(win_size, 3)
ssim += ssim_metric(target_img, output_img, data_range=1.0, channel_axis=-1, win_size=win_size)
return ssim / outputs.shape[0]
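# Label-style Wasserstein loss helper (mean of prediction * target); the training loop below computes
# the critic and generator losses directly from the critic outputs instead of calling it.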
def wasserstein_loss(pred, target):
return torch.mean(pred * target)
from torch.autograd import grad
def compute_gradient_penalty(discriminator, real_samples, fake_samples, device):
alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
d_interpolates = discriminator(interpolates)
fake = torch.ones(real_samples.size(0), device=device)
gradients = grad(outputs=d_interpolates, inputs=interpolates, grad_outputs=fake, create_graph=True, retain_graph=True, only_inputs=True)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
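# One training epoch: the critic (discriminator) is updated on every batch with the WGAN-GP loss, while
# the generator is updated once every n_critic batches with an adversarial term plus an L1 pixel loss
# weighted by lambda_pixel.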
def train_correction_model(generator, discriminator, dataloader, optimizer_G, optimizer_D, device, lambda_gp, lambda_pixel, n_critic):
generator.train()
discriminator.train()
running_g_loss = 0.0
running_d_loss = 0.0
running_iou = 0.0
running_psnr = 0.0
running_ssim = 0.0
for batch_idx, (inputs, targets) in enumerate(tqdm(dataloader, desc="Training")):
inputs = inputs.to(device)
targets = targets.to(device)
# Train Discriminator
optimizer_D.zero_grad()
corrected_images = generator(inputs)
real_validity = discriminator(targets)
fake_validity = discriminator(corrected_images.detach())
gp = compute_gradient_penalty(discriminator, targets.data, corrected_images.data, device)
d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp
d_loss.backward()
optimizer_D.step()
# Train Generator
if batch_idx % n_critic == 0:
optimizer_G.zero_grad()
corrected_images = generator(inputs)
fake_validity = discriminator(corrected_images)
g_adv_loss = -torch.mean(fake_validity)
pixelwise_loss = nn.L1Loss()
g_pixel_loss = pixelwise_loss(corrected_images, targets)
g_loss = g_adv_loss + lambda_pixel * g_pixel_loss
g_loss.backward()
optimizer_G.step()
else:
g_loss = torch.tensor(0.0)
running_g_loss += g_loss.item()
running_d_loss += d_loss.item()
iou = compute_iou(corrected_images, targets)
psnr = compute_psnr(corrected_images, targets)
ssim = compute_ssim(corrected_images, targets)
running_iou += iou
running_psnr += psnr
running_ssim += ssim
epoch_g_loss = running_g_loss / len(dataloader)
epoch_d_loss = running_d_loss / len(dataloader)
epoch_iou = running_iou / len(dataloader)
epoch_psnr = running_psnr / len(dataloader)
epoch_ssim = running_ssim / len(dataloader)
return epoch_g_loss, epoch_d_loss, epoch_iou, epoch_psnr, epoch_ssim
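# Validation pass: runs under torch.no_grad(), so no gradient penalty is computed (the lambda_gp argument
# is accepted but unused); only the raw adversarial terms and the image metrics are reported for monitoring.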
def validate_correction_model(generator, discriminator, dataloader, device, lambda_gp):
generator.eval()
discriminator.eval()
running_g_loss = 0.0
running_d_loss = 0.0
running_iou = 0.0
running_psnr = 0.0
running_ssim = 0.0
with torch.no_grad():
for inputs, targets in tqdm(dataloader, desc="Validation"):
inputs = inputs.to(device)
targets = targets.to(device)
corrected_images = generator(inputs)
real_validity = discriminator(targets)
fake_validity = discriminator(corrected_images)
d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
g_adv_loss = -torch.mean(fake_validity)
g_loss = g_adv_loss
running_g_loss += g_loss.item()
running_d_loss += d_loss.item()
iou = compute_iou(corrected_images, targets)
psnr = compute_psnr(corrected_images, targets)
ssim = compute_ssim(corrected_images, targets)
running_iou += iou
running_psnr += psnr
running_ssim += ssim
epoch_g_loss = running_g_loss / len(dataloader)
epoch_d_loss = running_d_loss / len(dataloader)
epoch_iou = running_iou / len(dataloader)
epoch_psnr = running_psnr / len(dataloader)
epoch_ssim = running_ssim / len(dataloader)
return epoch_g_loss, epoch_d_loss, epoch_iou, epoch_psnr, epoch_ssim
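# Saves a 3-row comparison grid (ground truth / simulated colorblind input / corrected output)
# for the first batch drawn from the dataloader.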
def visualize_results(generator, dataloader, device, num_images=10, save_path='./results'):
generator.eval()
inputs, targets = next(iter(dataloader))
inputs = inputs.to(device)
targets = targets.to(device)
with torch.no_grad():
corrected_images = generator(inputs)
inputs = inputs.cpu().numpy()
targets = targets.cpu().numpy()
corrected_images = corrected_images.cpu().numpy()
if not os.path.exists(save_path):
os.makedirs(save_path)
plt.figure(figsize=(20, 10))
for i in range(num_images):
plt.subplot(3, num_images, i + 1)
plt.imshow(targets[i].transpose(1, 2, 0))
plt.title("Original")
plt.axis('off')
plt.subplot(3, num_images, i + 1 + num_images)
plt.imshow(inputs[i].transpose(1, 2, 0))
plt.title("Simulated Colorblind")
plt.axis('off')
plt.subplot(3, num_images, i + 1 + 2 * num_images)
plt.imshow(corrected_images[i].transpose(1, 2, 0))
plt.title("Corrected")
plt.axis('off')
plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'visualization.png'))
plt.show()
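# Plots per-epoch training and validation curves (generator loss, discriminator loss, IoU, PSNR, SSIM)
# and saves each figure under `path`.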
def plot_and_save_metrics(train_metrics, val_metrics, epoch, path='./metrics_plots'):
if not os.path.exists(path):
os.makedirs(path)
epochs = np.arange(1, epoch + 1)
train_g_losses, train_d_losses, train_ious, train_psnrs, train_ssims = zip(*train_metrics)
val_g_losses, val_d_losses, val_ious, val_psnrs, val_ssims = zip(*val_metrics)
plt.figure()
plt.plot(epochs, train_g_losses, label='Training Generator Loss')
plt.plot(epochs, val_g_losses, label='Validation Generator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Generator Loss over Epochs')
plt.legend()
plt.savefig(f'{path}/generator_loss.png')
plt.close()
plt.figure()
plt.plot(epochs, train_d_losses, label='Training Discriminator Loss')
plt.plot(epochs, val_d_losses, label='Validation Discriminator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Discriminator Loss over Epochs')
plt.legend()
plt.savefig(f'{path}/discriminator_loss.png')
plt.close()
plt.figure()
plt.plot(epochs, train_ious, label='Training IoU')
plt.plot(epochs, val_ious, label='Validation IoU')
plt.xlabel('Epoch')
plt.ylabel('IoU')
plt.title('IoU over Epochs')
plt.legend()
plt.savefig(f'{path}/iou.png')
plt.close()
plt.figure()
plt.plot(epochs, train_psnrs, label='Training PSNR')
plt.plot(epochs, val_psnrs, label='Validation PSNR')
plt.xlabel('Epoch')
plt.ylabel('PSNR')
plt.title('PSNR over Epochs')
plt.legend()
plt.savefig(f'{path}/psnr.png')
    plt.close()
    plt.figure()
    plt.plot(epochs, train_ssims, label='Training SSIM')
    plt.plot(epochs, val_ssims, label='Validation SSIM')
    plt.xlabel('Epoch')
    plt.ylabel('SSIM')
    plt.title('SSIM over Epochs')
    plt.legend()
    plt.savefig(f'{path}/ssim.png')
    plt.close()
```
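The post is cut off before any driver code appears, so the following is only a hypothetical sketch of how the functions above could be wired together. The synthetic tensors, batch size, epoch count, optimizer settings (Adam with betas=(0.5, 0.9), a common WGAN-GP choice), and loss weights are placeholders, not values from the original article; a real run would use a paired dataset of 3x256x256 images in [0, 1].

```python
# Hypothetical driver code -- every value below is a placeholder, not from the original post.
from torch.utils.data import TensorDataset, DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
discriminator = Discriminator().to(device)
optimizer_G = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))

# Synthetic stand-in data so the sketch runs end to end; replace with the real paired dataset.
x = torch.rand(8, 3, 256, 256)   # simulated colorblind inputs
y = torch.rand(8, 3, 256, 256)   # ground-truth targets
train_loader = DataLoader(TensorDataset(x, y), batch_size=4, shuffle=True)
val_loader = DataLoader(TensorDataset(x, y), batch_size=4)

train_metrics, val_metrics = [], []
for epoch in range(1, 3):        # placeholder epoch count
    train_metrics.append(train_correction_model(
        generator, discriminator, train_loader, optimizer_G, optimizer_D,
        device, lambda_gp=10.0, lambda_pixel=100.0, n_critic=5))
    val_metrics.append(validate_correction_model(
        generator, discriminator, val_loader, device, lambda_gp=10.0))
    plot_and_save_metrics(train_metrics, val_metrics, epoch)
visualize_results(generator, val_loader, device, num_images=4)
```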