GAN对抗网络实例，并且有完整代码
时间: 2024-05-03 15:21:12 浏览: 140
一个简单的GAN网络实例
3星 · 编辑精心推荐
以下是一个使用 PyTorch 编写的简单生成对抗网络（GAN）实例，在 MNIST 手写数字数据集上训练（运行前需另外安装 torchvision 和 matplotlib）。
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Generator network.
class Generator(nn.Module):
    """Three-layer MLP mapping a latent noise vector to a flattened sample.

    Args:
        input_size: Dimension of the latent noise vector.
        hidden_size: Width of both hidden layers.
        output_size: Dimension of each generated sample (784, i.e. a
            flattened 28x28 image, in this script).

    The layer sizes are kept as attributes so training code can read
    ``generator.input_size`` to draw noise of the right width.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Exposed because the training functions size their noise batch via
        # `generator.input_size`; without it they raise AttributeError.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Map noise of shape (batch, input_size) to (batch, output_size).

        The final layer has no activation, so outputs are unbounded.
        """
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        return self.fc3(x)
# Discriminator network.
class Discriminator(nn.Module):
    """MLP that scores flattened samples, squashing the result with a sigmoid.

    Args:
        input_size: Dimension of each flattened input sample.
        hidden_size: Width of both hidden layers.
        output_size: Number of scores produced per sample (1 in this script).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Keep the fc1 -> fc2 -> fc3 creation order: it determines how a
        # seeded RNG initialises the weights.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return sigmoid scores of shape (batch, output_size)."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.activation(layer(hidden))
        logits = self.fc3(hidden)
        return torch.sigmoid(logits)
# Training functions.
def train_generator(generator, discriminator, optimizer, loss_func, batch_size):
    """Run one optimisation step for the generator.

    Draws ``batch_size`` noise vectors, generates fake samples, and updates
    the generator so the discriminator scores them as real (target 1).

    Args:
        generator: Module mapping noise to samples. The noise width is read
            from ``generator.input_size`` or, if absent,
            ``generator.fc1.in_features``.
        discriminator: Module whose output is compatible with ``loss_func``.
        optimizer: Optimizer over the generator's parameters.
        loss_func: Loss called as ``loss_func(output, target)``, e.g.
            ``nn.BCELoss()``.
        batch_size: Number of fake samples to generate.

    Returns:
        The generator loss as a Python float.
    """
    optimizer.zero_grad()
    noise_dim = getattr(generator, "input_size", None)
    if noise_dim is None:
        # Fallback for generators that do not expose `input_size`.
        noise_dim = generator.fc1.in_features
    noise = torch.randn(batch_size, noise_dim)
    fake_images = generator(noise)
    output = discriminator(fake_images)
    # "Real" targets built from the output itself so shape and dtype always
    # match, whatever the discriminator's output width.
    loss = loss_func(output, torch.ones_like(output))
    # NOTE: backward also fills the discriminator's .grad buffers. In this
    # script that is harmless because train_discriminator calls zero_grad()
    # on the discriminator optimizer before its own step.
    loss.backward()
    optimizer.step()
    return loss.item()
def train_discriminator(generator, discriminator, optimizer, loss_func, real_images, batch_size):
    """Run one optimisation step for the discriminator.

    Trains the discriminator to score ``real_images`` as real (target 1) and
    ``batch_size`` freshly generated samples as fake (target 0).

    Args:
        generator: Module mapping noise to samples. The noise width is read
            from ``generator.input_size`` or, if absent,
            ``generator.fc1.in_features``.
        discriminator: Module whose output is compatible with ``loss_func``.
        optimizer: Optimizer over the discriminator's parameters.
        loss_func: Loss called as ``loss_func(output, target)``, e.g.
            ``nn.BCELoss()``.
        real_images: Batch of real samples, already flattened to the
            discriminator's input width. It may contain fewer rows than
            ``batch_size`` (e.g. the last batch of an epoch).
        batch_size: Number of fake samples to generate.

    Returns:
        The summed real + fake loss as a Python float.
    """
    optimizer.zero_grad()
    output_real = discriminator(real_images)

    noise_dim = getattr(generator, "input_size", None)
    if noise_dim is None:
        # Fallback for generators that do not expose `input_size`.
        noise_dim = generator.fc1.in_features
    noise = torch.randn(batch_size, noise_dim)
    # The generator is not updated here, so don't build its autograd graph
    # (same effect as detaching the fakes, without the wasted bookkeeping).
    with torch.no_grad():
        fake_images = generator(noise)
    output_fake = discriminator(fake_images)

    # Targets are derived from each output so their shapes always match,
    # even when real_images is a short final batch.
    loss_real = loss_func(output_real, torch.ones_like(output_real))
    loss_fake = loss_func(output_fake, torch.zeros_like(output_fake))
    loss = loss_real + loss_fake
    loss.backward()
    optimizer.step()
    return loss.item()
# Hyperparameters.
lr = 2e-4               # Adam learning rate, shared by both optimizers
batch_size = 64         # samples per DataLoader batch
input_size = 100        # latent noise width fed to the generator
hidden_size = 128       # hidden-layer width of both networks
output_size = 28 * 28   # flattened MNIST image size (784)
num_epochs = 100        # passes over the training set
# Load the MNIST training set (downloaded into ./data on first run).
# torchvision is required here; the original snippet used it without
# importing it, which raised NameError.
import torchvision

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'data', train=True, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            # Normalise with the commonly used MNIST mean / std.
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=batch_size,
    shuffle=True,
)
# Build the two networks: the generator maps noise to flattened images and
# the discriminator maps a flattened image to a single sigmoid score.
generator = Generator(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
discriminator = Discriminator(input_size=output_size, hidden_size=hidden_size, output_size=1)

# Binary cross-entropy on the discriminator's sigmoid output, with a separate
# Adam optimizer per network so each step only updates its own parameters.
loss_func = nn.BCELoss()
generator_optimizer = optim.Adam(params=generator.parameters(), lr=lr)
discriminator_optimizer = optim.Adam(params=discriminator.parameters(), lr=lr)
# Training loop: one discriminator step followed by one generator step per batch.
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Flatten each image to a 1-D vector using the batch's real size.
        # The final batch of an epoch can be shorter than `batch_size`
        # (MNIST's 60,000 training images leave 32 in the last batch of 64).
        # Using view(batch_size, -1) would silently reshape it to the wrong
        # width and crash the discriminator.
        current_batch = images.size(0)
        images = images.view(current_batch, -1)
        # Train the discriminator on this batch plus an equal number of fakes.
        d_loss = train_discriminator(generator, discriminator, discriminator_optimizer,
                                     loss_func, images, current_batch)
        # Train the generator.
        g_loss = train_generator(generator, discriminator, generator_optimizer,
                                 loss_func, batch_size)
        if i % 100 == 0:
            # Both epoch and step are reported 1-based.
            print("Epoch [{}/{}], Step [{}/{}], Discriminator Loss: {:.4f}, Generator Loss: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(train_loader), d_loss, g_loss))
# Sample a batch of fake images from the trained generator. No gradients are
# needed for inference, so skip building the autograd graph.
with torch.no_grad():
    noise = torch.randn(batch_size, input_size)
    fake_images = generator(noise)

# Visualise the first generated sample as a 28x28 grayscale image.
# NOTE(review): values are in the generator's unbounded output space; imshow
# rescales to the data range by default.
import matplotlib.pyplot as plt
plt.imshow(fake_images[0].view(28, 28).numpy(), cmap='gray')
# Without show() nothing is displayed when this runs as a plain script.
plt.show()
```
阅读全文