你能帮我写一个pytorch的GAN生成网络吗
时间: 2023-05-27 21:06:52 浏览: 144
当然可以,以下是一个简单的PyTorch GAN生成网络的示例代码,它用于生成手写数字的图像:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
# ---- Hyperparameters ----
latent_size = 100  # dimensionality of the generator's input noise vector
hidden_size = 256  # width of the hidden layers in both networks
image_size = 784   # flattened 28x28 MNIST image (28 * 28 = 784 pixels)
batch_size = 64    # samples per training step (last batch of an epoch may be smaller)
num_epochs = 100   # full passes over the training set
# ---- Generator ----
class Generator(nn.Module):
    """Three-layer MLP generator: latent noise -> flattened image.

    The Tanh output lands in [-1, 1], matching real images preprocessed
    with ``Normalize((0.5,), (0.5,))``.

    Args:
        latent_dim: size of the input noise vector (default 100, matching
            the module-level ``latent_size``).
        hidden_dim: width of both hidden layers (default 256).
        out_dim: number of output pixels (default 784 = 28*28 for MNIST).
    """

    def __init__(self, latent_dim=100, hidden_dim=256, out_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map noise of shape (batch, latent_dim) to (batch, out_dim) in [-1, 1]."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.tanh(self.fc3(x))
# ---- Discriminator ----
class Discriminator(nn.Module):
    """Three-layer MLP discriminator: flattened image -> real-probability.

    The Sigmoid output is a scalar per sample in (0, 1), suitable for
    ``nn.BCELoss``.

    Args:
        in_dim: number of input pixels (default 784 = 28*28 for MNIST,
            matching the module-level ``image_size``).
        hidden_dim: width of both hidden layers (default 256).
    """

    def __init__(self, in_dim=784, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map images of shape (batch, in_dim) to probabilities of shape (batch, 1)."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.sigmoid(self.fc3(x))
# ---- Data pipeline ----
# Normalize((0.5,), (0.5,)) maps pixel values from [0, 1] to [-1, 1],
# matching the Tanh output range of the generator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Downloads MNIST to ./data on first run (network I/O).
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# ---- Models ----
generator = Generator()
discriminator = Discriminator()
# ---- Loss and optimizers ----
# BCELoss pairs with the discriminator's Sigmoid output; beta1=0.5 is the
# common Adam setting for GAN training.
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# ---- Training loop ----
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Use the actual batch size: the last batch of an epoch can be
        # smaller than `batch_size` (60000 % 64 != 0), and fixed-size
        # label tensors would make BCELoss raise a shape mismatch there.
        bsz = images.size(0)
        real_labels = torch.ones(bsz, 1)
        fake_labels = torch.zeros(bsz, 1)

        # ---- Train the discriminator ----
        # Loss on real samples (flatten 28x28 images to 784-vectors).
        real_images = images.view(bsz, image_size)
        real_outputs = discriminator(real_images)
        d_loss_real = criterion(real_outputs, real_labels)

        # Loss on generated samples; detach() stops gradients from
        # flowing into the generator during the discriminator step.
        noise = torch.randn(bsz, latent_size)
        fake_images = generator(noise)
        fake_outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(fake_outputs, fake_labels)

        # Total discriminator loss and parameter update.
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Train the generator ----
        # The generator wins when the discriminator labels fakes as real.
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Periodic progress report.
        if (i+1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], "
                  f"d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}")

    # Save a grid of generated samples every 10 epochs; no_grad() since
    # this is inference only.
    if (epoch+1) % 10 == 0:
        with torch.no_grad():
            samples = generator(torch.randn(batch_size, latent_size))
        save_image(samples.view(batch_size, 1, 28, 28), f"gan_images_{epoch+1}.png")
```
在这个示例中,我们定义了一个简单的生成器和判别器,使用MNIST数据集训练了模型,并保存了生成器的输出图像。在每个epoch中,我们首先训练判别器来区分真实和生成的样本,然后训练生成器来欺骗判别器。最后,我们保存生成器的输出图像。
这只是一个简单的示例,你可以根据需要对模型进行调整和改进。
阅读全文