写一段pytorch生成对抗网络代码
时间: 2023-03-03 16:39:21 浏览: 187
机器学习生成对抗网络(附代码)
5星 · 资源好评率100%
以下是一个简单的 PyTorch 生成对抗网络(GAN)的示例代码:
```
import torch
import torch.nn as nn
import torch.optim as optim
# 定义判别器网络
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(784, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
x = x.view(-1, 784)
out = self.model(x)
return out
# 定义生成器网络
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 784),
nn.Tanh()
)
def forward(self, x):
out = self.model(x)
out = out.view(-1, 1, 28, 28)
return out
# 初始化网络和优化器
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
discriminator = Discriminator().to(device)
generator = Generator().to(device)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
# 定义损失函数
criterion = nn.BCELoss()
# 训练 GAN
for epoch in range(200):
for i, data in enumerate(dataloader, 0):
# 训练判别器
discriminator.zero_grad()
real = data[0].to(device)
batch_size = real.size(0)
label = torch.full((batch_size,), 1, device=device)
output = discriminator(real).view(-1)
error_D_real = criterion(output, label)
error_D_real.backward()
noise = torch.randn(batch_size, 100, device=device)
fake = generator(noise)
label.fill_(0)
output = discriminator(fake.detach()).view(-1)
error_D_fake = criterion(output, label)
error_D_fake.backward()
error_D = error_D_real + error_D_fake
optimizer_D.step()
# 训练生成器
generator.zero_grad()
label.fill_(1)
output = discriminator(fake).view(-1)
error_G = criterion(output, label)
error_G.backward()
optimizer_G.step()
# 输出损失值
if i % 100 == 0:
print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
% (epoch, 200, i, len(dataloader),
error_D.item(), error_G.item()))
```
这段代码定义了一个基于 MNIST 数据集的 GAN,包括一个判别器网络和一个生成器网络。在训练过程中,判别器和生成器分别被训练,通过交替训练来优化模型。训
阅读全文