cycle GAN代码
时间: 2023-07-02 09:10:37 浏览: 116
这里是一个简单的CycleGAN代码示例,使用PyTorch实现。CycleGAN是一种无监督的图像转换模型,可以将一类图像转换为另一类图像。
```python
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
class CycleGANDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.img_names = os.listdir(root_dir)
self.transform = transform
def __len__(self):
return len(self.img_names)
def __getitem__(self, idx):
img_name = self.img_names[idx]
img_path = os.path.join(self.root_dir, img_name)
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image
class Generator(nn.Module):
    """Encoder-decoder generator.

    Six stride-2 convolutions halve the spatial size at each stage (the first
    stage has no BatchNorm), then six stride-2 transposed convolutions restore
    it, ending in tanh so outputs lie in [-1, 1].
    """

    @staticmethod
    def _down(in_ch, out_ch, normalize=True):
        # One encoder stage: 4x4 stride-2 conv halves H and W; BatchNorm on
        # every stage except the first, matching the original layer sequence.
        stage = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)]
        if normalize:
            stage.append(nn.BatchNorm2d(out_ch))
        stage.append(nn.ReLU(inplace=True))
        return stage

    @staticmethod
    def _up(in_ch, out_ch):
        # One decoder stage: 4x4 stride-2 transposed conv doubles H and W.
        return [
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    def __init__(self):
        super(Generator, self).__init__()
        enc_layers = self._down(3, 64, normalize=False)
        for cin, cout in ((64, 128), (128, 256), (256, 512), (512, 512), (512, 512)):
            enc_layers += self._down(cin, cout)
        self.encoder = nn.Sequential(*enc_layers)

        dec_layers = []
        for cin, cout in ((512, 512), (512, 512), (512, 256), (256, 128), (128, 64)):
            dec_layers += self._up(cin, cout)
        # Final stage maps back to 3 channels with tanh instead of BN+ReLU.
        dec_layers += [
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Map a (N, 3, H, W) batch to a (N, 3, H, W) translated batch."""
        return self.decoder(self.encoder(x))
class Discriminator(nn.Module):
    """Patch-style discriminator.

    Three stride-2 convs downsample, two stride-1 convs refine, and a sigmoid
    squashes the single-channel output into per-patch real/fake scores in (0, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Individual attributes (rather than a Sequential) so forward() can
        # interleave normalization and activation explicitly; names are kept
        # stable so checkpoints serialize identically.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        """Return sigmoid scores; spatial size depends on the input size."""
        # First stage has no BatchNorm, matching the original network.
        h = nn.functional.leaky_relu(self.conv1(x), negative_slope=0.2, inplace=True)
        for conv, bn in ((self.conv2, self.bn2),
                         (self.conv3, self.bn3),
                         (self.conv4, self.bn4)):
            h = nn.functional.leaky_relu(bn(conv(h)), negative_slope=0.2, inplace=True)
        return torch.sigmoid(self.conv5(h))
class CycleGAN():
    """Bundle of the two generators, two discriminators, and their optimizers
    for unpaired A<->B image translation.

    Naming convention used throughout: generator_A maps domain A -> B,
    generator_B maps domain B -> A; discriminator_A scores domain-A images,
    discriminator_B scores domain-B images.
    """

    def __init__(self, device, lr=0.0002, lambda_cycle=10):
        """
        Args:
            device: torch.device all four networks are placed on.
            lr: Adam learning rate shared by the generator and discriminator
                optimizers (betas (0.5, 0.999), the usual GAN setting).
            lambda_cycle: weight of the L1 cycle-consistency term relative to
                the adversarial terms.
        """
        self.device = device
        self.generator_A = Generator().to(device)
        self.generator_B = Generator().to(device)
        self.discriminator_A = Discriminator().to(device)
        self.discriminator_B = Discriminator().to(device)
        # One optimizer over both generators, one over both discriminators,
        # so each alternating step updates a pair of networks together.
        self.optimizer_G = optim.Adam(list(self.generator_A.parameters()) + list(self.generator_B.parameters()), lr=lr, betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(list(self.discriminator_A.parameters()) + list(self.discriminator_B.parameters()), lr=lr, betas=(0.5, 0.999))
        self.criterion_cycle = nn.L1Loss()
        # BCE pairs with the discriminators' sigmoid outputs.
        self.criterion_adv = nn.BCELoss()
        self.lambda_cycle = lambda_cycle

    def train(self, dataloader_A, dataloader_B, num_epochs):
        """Run the alternating generator/discriminator training loop.

        The two dataloaders are zipped, so each epoch covers
        min(len(dataloader_A), len(dataloader_B)) batches; batches are treated
        as unpaired samples from the two domains.
        """
        self.generator_A.train()
        self.generator_B.train()
        self.discriminator_A.train()
        self.discriminator_B.train()
        for epoch in range(num_epochs):
            for i, (real_A, real_B) in enumerate(zip(dataloader_A, dataloader_B)):
                # Move data to the training device.
                real_A, real_B = real_A.to(self.device), real_B.to(self.device)
                # --- Generator step ---
                self.optimizer_G.zero_grad()
                # Adversarial loss: each generator tries to make its fakes
                # score as real (target = ones) under the opposing discriminator.
                fake_B = self.generator_A(real_A)
                pred_fake_B = self.discriminator_B(fake_B)
                loss_adv_B = self.criterion_adv(pred_fake_B, torch.ones_like(pred_fake_B))
                fake_A = self.generator_B(real_B)
                pred_fake_A = self.discriminator_A(fake_A)
                loss_adv_A = self.criterion_adv(pred_fake_A, torch.ones_like(pred_fake_A))
                # Cycle-consistency loss: A -> B -> A (and B -> A -> B) should
                # reconstruct the original image.
                cycle_A = self.generator_B(fake_B)
                loss_cycle_A = self.criterion_cycle(cycle_A, real_A)
                cycle_B = self.generator_A(fake_A)
                loss_cycle_B = self.criterion_cycle(cycle_B, real_B)
                # Total generator loss, with the cycle term scaled by lambda_cycle.
                loss_G = loss_adv_A + loss_adv_B + self.lambda_cycle * (loss_cycle_A + loss_cycle_B)
                loss_G.backward()
                self.optimizer_G.step()
                # --- Discriminator step ---
                self.optimizer_D.zero_grad()
                # Real images should score as real (target = ones).
                pred_real_A = self.discriminator_A(real_A)
                loss_real_A = self.criterion_adv(pred_real_A, torch.ones_like(pred_real_A))
                pred_real_B = self.discriminator_B(real_B)
                loss_real_B = self.criterion_adv(pred_real_B, torch.ones_like(pred_real_B))
                # Fakes (detached so this step does not backprop into the
                # generators) should score as fake (target = zeros).
                pred_fake_A = self.discriminator_A(fake_A.detach())
                loss_fake_A = self.criterion_adv(pred_fake_A, torch.zeros_like(pred_fake_A))
                pred_fake_B = self.discriminator_B(fake_B.detach())
                loss_fake_B = self.criterion_adv(pred_fake_B, torch.zeros_like(pred_fake_B))
                # Average of the four discriminator terms.
                loss_D = (loss_real_A + loss_fake_A + loss_real_B + loss_fake_B) * 0.25
                loss_D.backward()
                self.optimizer_D.step()
                # Periodic progress report.
                if i % 100 == 0:
                    print('[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]' %
                          (epoch+1, num_epochs, i+1, min(len(dataloader_A), len(dataloader_B)), loss_D.item(), loss_G.item()))
# Example usage
def main():
    """Build both domain datasets/dataloaders and run CycleGAN training."""
    # Prefer the first GPU when available, otherwise fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Resize the short side to 256, random-crop a 256x256 patch, then map
    # pixel values from [0, 1] to [-1, 1] to match the generators' tanh range.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset_A = CycleGANDataset('path/to/dataset_A', transform=transform)
    dataset_B = CycleGANDataset('path/to/dataset_B', transform=transform)
    dataloader_A = DataLoader(dataset_A, batch_size=1, shuffle=True)
    dataloader_B = DataLoader(dataset_B, batch_size=1, shuffle=True)
    cyclegan = CycleGAN(device, lr=0.0002, lambda_cycle=10)
    cyclegan.train(dataloader_A, dataloader_B, num_epochs=200)


if __name__ == "__main__":
    # Guard so importing this module does not start a 200-epoch training run.
    main()
```
阅读全文