GAN Example in PyTorch
Below is an example of a GAN (Generative Adversarial Network) implemented in PyTorch:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define the generator network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Linear(100, 256)   # 100-dim noise vector -> 256 hidden units
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 784)  # 256 -> 784 (flattened 28x28 image)
        self.tanh = nn.Tanh()           # output in [-1, 1], matching the normalized data

    def forward(self, x):
        x = self.fc(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.tanh(x)
        return x

# Define the discriminator network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc = nn.Linear(784, 256)   # flattened 28x28 image -> 256 hidden units
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 1)    # single output logit
        self.sigmoid = nn.Sigmoid()     # probability that the input is real

    def forward(self, x):
        x = self.fc(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

# Define the training function
def train_gan():
    # Set the random seed for reproducibility
    torch.manual_seed(1)

    # Load the MNIST dataset, normalized to [-1, 1] to match the Tanh output
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

    # Initialize the generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    # Define the loss function and optimizers
    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)

    # Start training
    num_epochs = 50
    for epoch in range(num_epochs):
        for i, data in enumerate(trainloader):
            real_images, _ = data
            batch_size = real_images.size(0)
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)

            # Train the discriminator: real images should score 1, fake images 0
            discriminator.zero_grad()
            real_outputs = discriminator(real_images.view(batch_size, -1))
            real_loss = criterion(real_outputs, real_labels)
            real_score = torch.mean(real_outputs).item()

            noise = torch.randn(batch_size, 100)
            fake_images = generator(noise)
            # detach() so this pass does not backpropagate into the generator
            fake_outputs = discriminator(fake_images.detach())
            fake_loss = criterion(fake_outputs, fake_labels)
            fake_score = torch.mean(fake_outputs).item()

            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_d.step()

            # Train the generator: try to make the discriminator label fakes as real
            generator.zero_grad()
            fake_outputs = discriminator(fake_images)
            g_loss = criterion(fake_outputs, real_labels)
            g_loss.backward()
            optimizer_g.step()

            # Print training progress
            if (i+1) % 200 == 0:
                print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, '
                      'D(x): {:.2f}, D(G(z)): {:.2f}'
                      .format(epoch+1, num_epochs, i+1, len(trainloader),
                              d_loss.item(), g_loss.item(), real_score, fake_score))

# Run training
train_gan()
```
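After training, you would normally sample from the generator to inspect the results. The snippet below is a minimal sketch of that step, assuming the trained `generator` from the script above is available (for example, by modifying `train_gan()` to return it); the `matplotlib` plotting and the 4x4 grid size are illustrative choices, not part of the original example.

```python
import torch
import matplotlib.pyplot as plt

# Assumption: `generator` should hold the trained model from the script above,
# e.g. obtained by modifying train_gan() to end with `return generator`.
generator = Generator()  # stands in for the trained instance here

generator.eval()
with torch.no_grad():
    noise = torch.randn(16, 100)                     # same 100-dim latent vectors as in training
    samples = generator(noise)                       # values in [-1, 1] because of the Tanh output
    images = (samples.view(-1, 1, 28, 28) + 1) / 2   # reshape to 28x28 and rescale to [0, 1]

# Show a 4x4 grid of generated digits
fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for ax, img in zip(axes.flat, images):
    ax.imshow(img.squeeze(), cmap='gray')
    ax.axis('off')
plt.show()
```

Wrapping the forward pass in `torch.no_grad()` avoids building an autograd graph during sampling, which is unnecessary once training is done.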