用PyTorch写一个GAN代码
时间: 2023-03-23 17:03:54 浏览: 117
这里提供一个简单的GAN代码，使用PyTorch实现：

```python
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image


# Define the generator
class Generator(nn.Module):
    """Fully connected GAN generator mapping a noise vector to a flat image.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> Tanh, so every
    output element lies in [-1, 1].

    Args:
        input_size: length of the input noise vector (last dim of ``x``).
        hidden_size: width of the two hidden layers.
        output_size: length of the generated output vector.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        # Tanh bounds the output to [-1, 1], the same range that
        # Normalize(mean=0.5, std=0.5) maps ToTensor's [0, 1] pixels into.
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map noise of shape (..., input_size) to (..., output_size)."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.tanh(self.fc3(x))


# Define the discriminator
class Discriminator(nn.Module):
    """Fully connected GAN discriminator scoring inputs as real vs. fake.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid, so
    every output element lies in [0, 1] and can be fed directly to
    ``nn.BCELoss``.

    Args:
        input_size: length of the (flattened) input vector.
        hidden_size: width of the two hidden layers.
        output_size: number of output scores (1 for a single real/fake score).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        # Sigmoid turns the final logits into probabilities, as BCELoss expects.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map inputs of shape (..., input_size) to scores (..., output_size)."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.sigmoid(self.fc3(x))


# Hyperparameters
batch_size: int = 100       # images per DataLoader batch / training step
input_size: int = 100       # length of the noise vector fed to the generator
hidden_size: int = 256      # width of the hidden layers in G and D
output_size: int = 28 * 28  # one flattened 28x28 MNIST image (784 values)
num_epochs: int = 200       # full passes over the training set
# Load the MNIST dataset
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps
# them to [-1, 1], matching the range of the generator's Tanh output.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(.5,), std=(.5,))]
)
train_dataset = dset.MNIST(
    root='./data',
    train=True,
    download=True,  # downloads into root if the files are not already there
    transform=transform,
)
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=batch_size,
)
# Initialize the generator and discriminator
# Generator: noise vector (input_size) -> flattened image (output_size).
G = Generator(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)
# Discriminator: flattened image (output_size) -> one real/fake score.
D = Discriminator(
    input_size=output_size,
    hidden_size=hidden_size,
    output_size=1,
)
# Define the loss function and optimizers
criterion = nn.BCELoss()  # binary cross-entropy on D's sigmoid output
G_optimizer = optim.Adam(params=G.parameters(), lr=2e-4)
D_optimizer = optim.Adam(params=D.parameters(), lr=2e-4)
# Train the GAN
# Create the output folder up front: saving an image into a missing
# directory raises FileNotFoundError.
os.makedirs('./samples', exist_ok=True)

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Use the real size of this batch: the last batch is smaller whenever
        # the dataset size is not a multiple of batch_size.
        cur_batch_size = images.size(0)
        real_labels = torch.ones(cur_batch_size, 1)
        fake_labels = torch.zeros(cur_batch_size, 1)
        real_images = images.view(cur_batch_size, -1)

        # ---- Train the discriminator: real -> 1, fake -> 0 ----
        z = torch.randn(cur_batch_size, input_size)
        # detach(): D's loss must not backpropagate through G; G gets its own
        # update below, so tracing its graph here is wasted work.
        fake_images = G(z).detach()
        D_real_loss = criterion(D(real_images), real_labels)
        D_fake_loss = criterion(D(fake_images), fake_labels)
        D_loss = D_real_loss + D_fake_loss
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # ---- Train the generator: push D to label fresh fakes as real ----
        z = torch.randn(cur_batch_size, input_size)
        fake_images = G(z)
        G_loss = criterion(D(fake_images), real_labels)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # Print the losses every 100 steps.
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], D_loss: {:.4f}, G_loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, len(train_loader), D_loss.item(), G_loss.item()))

    # Save the last batch of generated images every 10 epochs.
    if (epoch + 1) % 10 == 0:
        samples = fake_images.detach().view(-1, 1, 28, 28)
        # G ends in Tanh (range [-1, 1]) but save_image expects [0, 1] and
        # would clip every negative pixel to black, so rescale first.
        save_image((samples + 1) / 2, './samples/sample_{}.png'.format(epoch+1))
```

这段代码实现了一个简单的GAN模型，用于生成MNIST手写数字图像。其中，生成器和判别器都是简单的全连接神经网络（生成器输出层使用Tanh，判别器输出层使用Sigmoid），损失函数采用二元交叉熵（BCELoss），优化器采用Adam算法。在训练过程中，每100个迭代步打印一次当前的损失值，每10个epoch将生成的图像保存到./samples目录。
阅读全文