gan 训练代码
时间: 2023-07-03 16:24:23 浏览: 144
以下是一个简单的GAN训练代码的示例,其中包括了生成器和判别器的定义、损失函数的定义、优化器的设置以及具体的训练过程。请注意,这个代码仅供参考,实际应用中可能需要根据具体的问题进行调整和修改。
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Generator network: maps latent noise vectors to flat image vectors.
class Generator(nn.Module):
    """Two-layer MLP generator.

    Takes a latent vector of size ``input_size`` and produces a flat
    vector of ``output_size`` values squashed into (-1, 1) by Tanh.
    Defaults target MNIST (784 = 28 * 28 pixels).
    """

    def __init__(self, input_size=100, hidden_size=128, output_size=784):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        # Hidden representation, then Tanh-bounded output in (-1, 1).
        hidden = self.relu(self.fc1(x))
        return self.tanh(self.fc2(hidden))
# Discriminator network: scores a flat image vector as real (1) or fake (0).
class Discriminator(nn.Module):
    """Two-layer MLP discriminator.

    Maps a flat input of size ``input_size`` to a single probability in
    (0, 1). Note both layers use a sigmoid activation, mirroring the
    original design.
    """

    def __init__(self, input_size=784, hidden_size=128, output_size=1):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Sigmoid after each layer; final value is the real/fake score.
        hidden = self.sigmoid(self.fc1(x))
        return self.sigmoid(self.fc2(hidden))
# Loss: binary cross-entropy over the discriminator's (0, 1) scores.
criterion = nn.BCELoss()
# Training data: MNIST digits. ToTensor scales pixels to [0, 1].
# NOTE(review): the generator ends in Tanh (range [-1, 1]) while these real
# images are in [0, 1] — ranges should be matched before comparing; verify
# how the training loop handles this.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# Instantiate generator and discriminator with their default sizes.
generator = Generator()
discriminator = Discriminator()
# One Adam optimizer per network so the two are updated independently.
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)
# Training loop: alternate one discriminator step and one generator step
# per batch.
def train(num_epochs):
    """Train the GAN for ``num_epochs`` passes over ``train_loader``.

    Each batch first updates the discriminator on real images (label 1)
    and detached fake images (label 0), then updates the generator to
    make the discriminator output "real" for its samples. Progress is
    printed every 100 steps.
    """
    for epoch in range(num_epochs):
        for i, (real_data, _) in enumerate(train_loader):
            batch_size = real_data.size(0)
            # Flatten images and rescale from ToTensor's [0, 1] to [-1, 1]
            # so real samples share the generator's Tanh output range;
            # otherwise the discriminator can tell real from fake by value
            # range alone and training degenerates.
            real_data = real_data.view(batch_size, -1) * 2 - 1
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)

            # --- Discriminator step ---
            discriminator.zero_grad()
            z = torch.randn(batch_size, 100)
            # detach() stops gradients from this step reaching the generator.
            fake_data = generator(z).detach()
            real_loss = criterion(discriminator(real_data), real_labels)
            fake_loss = criterion(discriminator(fake_data), fake_labels)
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_d.step()

            # --- Generator step ---
            generator.zero_grad()
            z = torch.randn(batch_size, 100)
            fake_data = generator(z)
            # The generator is rewarded when its samples are scored "real".
            g_loss = criterion(discriminator(fake_data), real_labels)
            g_loss.backward()
            optimizer_g.step()

            if i % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
                      .format(epoch, num_epochs, i, len(train_loader), d_loss.item(), g_loss.item()))

# Start training.
train(num_epochs=100)
```
在这个示例代码中,我们使用了MNIST数据集,生成器和判别器都是简单的全连接网络,使用了Adam优化器和BCELoss损失函数。在训练过程中,我们首先训练判别器,然后训练生成器。可以根据具体的应用场景和数据集进行修改和调整。
阅读全文