PyTorch 生成对抗网络 代码
时间: 2024-06-13 19:06:12 浏览: 188
以下是PyTorch实现生成对抗网络的代码示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
# 定义判别器
class Discriminator(nn.Module):
    """DCGAN-style discriminator for single-channel 64x64 images.

    Input:  (N, 1, 64, 64) tensor.
    Output: (N, 1, 1, 1) tensor of probabilities in [0, 1]
            (sigmoid applied to the final convolution).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Strided convolutions halve spatial size: 64 -> 32 -> 16 -> 8 -> 4.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv2_bn = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv3_bn = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.conv4_bn = nn.BatchNorm2d(512)
        # Final conv collapses the remaining 4x4 feature map to one logit.
        self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0)
        # Build the activation modules once here instead of re-instantiating
        # nn.LeakyReLU / nn.Sigmoid on every forward pass (original defect).
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.lrelu(self.conv1(x))
        x = self.lrelu(self.conv2_bn(self.conv2(x)))
        x = self.lrelu(self.conv3_bn(self.conv3(x)))
        x = self.lrelu(self.conv4_bn(self.conv4(x)))
        return self.sigmoid(self.conv5(x))
# 定义生成器
class Generator(nn.Module):
    """DCGAN-style generator.

    Input:  (N, 100, 1, 1) latent noise tensor.
    Output: (N, 1, 64, 64) image tensor with values in [-1, 1]
            (tanh applied to the final transposed convolution).
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Transposed convolutions upsample: 1 -> 4 -> 8 -> 16 -> 32 -> 64.
        self.deconv1 = nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0)
        self.deconv1_bn = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.deconv2_bn = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.deconv3_bn = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.deconv4_bn = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1)
        # Build the activation modules once here instead of re-instantiating
        # nn.ReLU / nn.Tanh on every forward pass (original defect).
        self.relu = nn.ReLU(inplace=True)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.deconv1_bn(self.deconv1(x)))
        x = self.relu(self.deconv2_bn(self.deconv2(x)))
        x = self.relu(self.deconv3_bn(self.deconv3(x)))
        x = self.relu(self.deconv4_bn(self.deconv4(x)))
        return self.tanh(self.deconv5(x))
# 定义训练函数
def train():
    """Alternately train the discriminator and generator on MNIST.

    Downloads MNIST into ./data on first run. Images are resized to
    64x64 to match the DCGAN architectures above and normalized to
    [-1, 1] to match the generator's tanh output range.
    """
    # Hyperparameters.
    batch_size = 128
    lr = 0.0002
    train_epoch = 20
    # MNIST is 28x28; the conv stacks require 64x64 input (the original
    # code crashed in D.conv5 on the resulting 1x1 feature map).
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True)
    # Models, optimizers, loss. betas=(0.5, 0.999) is the standard DCGAN
    # Adam setting and stabilizes adversarial training.
    D = Discriminator()
    G = Generator()
    D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    criterion = nn.BCELoss()
    for epoch in range(train_epoch):
        for i, (data, _) in enumerate(train_loader):
            # Size labels by the actual batch: the last batch of an epoch
            # can be smaller than batch_size (the original sized them by
            # batch_size and would fail there).
            n = data.size(0)
            real_label = torch.ones(n)
            fake_label = torch.zeros(n)
            # ---- Train the discriminator ----
            D.zero_grad()
            # Flatten D's (N, 1, 1, 1) output to (N,) to match the labels.
            real_output = D(data).view(-1)
            D_loss_real = criterion(real_output, real_label)
            z = torch.randn(n, 100, 1, 1)
            fake_data = G(z)
            # detach() so the discriminator update does not backpropagate
            # through (and accumulate gradients into) the generator.
            fake_output = D(fake_data.detach()).view(-1)
            D_loss_fake = criterion(fake_output, fake_label)
            D_loss = D_loss_real + D_loss_fake
            D_loss.backward()
            D_optimizer.step()
            # ---- Train the generator ----
            G.zero_grad()
            # Reuse fake_data (graph to G is intact); G is trained to make
            # D label its samples as real.
            fake_output = D(fake_data).view(-1)
            G_loss = criterion(fake_output, real_label)
            G_loss.backward()
            G_optimizer.step()
            if i % 100 == 0:
                print('Epoch [%d/%d], Step[%d/%d], D_loss: %.4f, G_loss: %.4f'
                      % (epoch, train_epoch, i, len(train_loader),
                         D_loss.item(), G_loss.item()))


# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    train()
```
阅读全文