GAN 的 PyTorch 实现
时间: 2023-08-15 10:28:32 浏览: 125
GAN（Generative Adversarial Network，生成对抗网络）是一种常用的生成模型，可以用来生成与训练数据类似的样本。下面是使用 PyTorch 实现 GAN 的示例代码：
```python
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 定义生成器模型
class Generator(nn.Module):
    """Fully connected GAN generator mapping latent noise to flat images.

    Args:
        input_dim: Size of the latent noise vector ``z``.
        output_dim: Number of output features per sample (784 = 28 * 28 for
            MNIST).

    The final activation is ``tanh``, so every output lies in ``[-1, 1]``.
    This matches real images normalized with ``Normalize((0.5,), (0.5,))``.
    """

    def __init__(self, input_dim=100, output_dim=784):
        super(Generator, self).__init__()
        # Layer names fc1..fc4 are kept stable so saved state_dicts still load.
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, output_dim)
        self.relu = nn.ReLU()
        # tanh rather than sigmoid: the training script normalizes real pixels
        # to [-1, 1]. A sigmoid output (confined to [0, 1]) could never match
        # the negative half of that range, letting the discriminator win
        # trivially.
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map noise of shape ``(batch, input_dim)`` to ``(batch, output_dim)``."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        return self.tanh(self.fc4(x))
# 定义判别器模型
class Discriminator(nn.Module):
    """Fully connected GAN discriminator scoring how "real" each sample looks.

    Args:
        input_dim: Number of features per sample after flattening
            (784 = 28 * 28 for MNIST).
        output_dim: Number of sigmoid scores produced per sample.
    """

    def __init__(self, input_dim=784, output_dim=1):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return scores in ``[0, 1]`` of shape ``(batch, output_dim)``.

        Accepts flat ``(batch, input_dim)`` vectors, such as generator output.
        It also accepts image batches such as ``(batch, 1, 28, 28)`` from the
        MNIST loader. Everything after the batch dimension is flattened first.
        """
        # Without flattening, nn.Linear would act on the last (width) axis of
        # an image batch and fail with a shape mismatch.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        return self.sigmoid(self.fc4(x))
# 定义训练函数
def train(disc_model, gen_model, disc_optimizer, gen_optimizer, criterion,
          dataloader, device, latent_dim=100):
    """Run one epoch of alternating discriminator / generator updates.

    Args:
        disc_model: Discriminator; expected to output probabilities, because
            the targets below are 0/1 labels for ``criterion``.
        gen_model: Generator mapping ``(batch, latent_dim)`` noise to samples.
        disc_optimizer: Optimizer over the discriminator's parameters.
        gen_optimizer: Optimizer over the generator's parameters.
        criterion: Loss called as ``criterion(prediction, target)``, e.g.
            ``nn.BCELoss()``.
        dataloader: Iterable of ``(data, label)`` batches; labels are ignored.
        device: Device that batches are moved to and noise is created on.
        latent_dim: Size of the noise vector fed to the generator.

    Returns:
        dict with the mean ``"disc_loss"`` and ``"gen_loss"`` over the epoch.
        Both are 0.0 if ``dataloader`` yields no batches.
    """
    disc_model.train()
    gen_model.train()
    disc_total = 0.0
    gen_total = 0.0
    num_batches = 0
    for real_data, _ in dataloader:
        # Flatten image batches (e.g. (B, 1, 28, 28)) to (B, features) so they
        # match the generator's flat output and the discriminator's first
        # Linear layer.
        real_data = real_data.to(device).flatten(start_dim=1)
        batch_size = real_data.size(0)

        # Discriminator step: push D(real) toward 1 and D(G(z)) toward 0.
        disc_optimizer.zero_grad()
        real_output = disc_model(real_data)
        real_loss = criterion(real_output, torch.ones_like(real_output))
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_data = gen_model(z)
        # detach(): this step must not backpropagate into the generator.
        fake_output = disc_model(fake_data.detach())
        fake_loss = criterion(fake_output, torch.zeros_like(fake_output))
        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        disc_optimizer.step()

        # Generator step: label fresh fakes as real (with BCELoss this is the
        # non-saturating generator loss). The gradients this leaves on the
        # discriminator are cleared by the next disc_optimizer.zero_grad().
        gen_optimizer.zero_grad()
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_output = disc_model(gen_model(z))
        gen_loss = criterion(fake_output, torch.ones_like(fake_output))
        gen_loss.backward()
        gen_optimizer.step()

        disc_total += disc_loss.item()
        gen_total += gen_loss.item()
        num_batches += 1

    if num_batches == 0:
        return {"disc_loss": 0.0, "gen_loss": 0.0}
    return {
        "disc_loss": disc_total / num_batches,
        "gen_loss": gen_total / num_batches,
    }
# 定义测试函数
def test(gen_model, device, num_samples=64, latent_dim=100):
    """Sample images from the generator and display them in a grid.

    Args:
        gen_model: Generator mapping ``(n, latent_dim)`` noise to outputs with
            784 values per sample, which are reshaped to 28x28 images.
        device: Device on which the noise is created.
        num_samples: Number of images to draw; 64 gives the original 8x8 grid.
        latent_dim: Size of the noise vector fed to the generator.

    Raises:
        ValueError: If ``num_samples`` is not positive.

    Leaves ``gen_model`` in eval mode.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    gen_model.eval()
    # Sampling only: skip building an autograd graph.
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim, device=device)
        # Move to the CPU once for the whole batch instead of per image.
        images = gen_model(z).view(num_samples, 28, 28).cpu().numpy()
    cols = math.ceil(math.sqrt(num_samples))
    rows = math.ceil(num_samples / cols)
    plt.figure(figsize=(8, 8))
    for i, image in enumerate(images):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.show()
# Training hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
lr = 0.0002
num_epoch = 20
# Latent noise size for the generator; it must match the noise size that
# train() and test() sample.
input_dim = 100
output_dim = 784  # flattened 28 * 28 MNIST image
# Load MNIST. Normalize((0.5,), (0.5,)) maps ToTensor's [0, 1] pixels to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Build models and optimizers. The dimensions are passed explicitly so that
# editing input_dim / output_dim above actually takes effect.
disc_model = Discriminator(input_dim=output_dim).to(device)
gen_model = Generator(input_dim=input_dim, output_dim=output_dim).to(device)
disc_optimizer = optim.Adam(disc_model.parameters(), lr=lr)
gen_optimizer = optim.Adam(gen_model.parameters(), lr=lr)
criterion = nn.BCELoss()
# Train the model
for epoch in range(num_epoch):
    train(disc_model, gen_model, disc_optimizer, gen_optimizer, criterion, train_dataloader, device)
    print(f'Epoch {epoch + 1} finished.')
# Visualize samples from the trained generator
test(gen_model, device)
```
以上代码实现了一个简单的 GAN 模型，使用了 PyTorch 框架和 MNIST 数据集。具体实现过程可以参考代码中的注释和 PyTorch 官方文档。
阅读全文