GAN PyTorch 代码
时间: 2023-10-22 21:07:34 浏览: 111
以下是一个简单的 GAN 模型的 PyTorch 代码示例：
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
# Generator model: maps latent noise vectors to flat generated samples.
class Generator(nn.Module):
    """Three-layer MLP generator.

    A noise vector of size ``latent_dim`` passes through two ReLU hidden layers
    of width ``hidden_dim``. The result is projected to ``output_dim`` values,
    which Tanh squashes into (-1, 1). The default ``output_dim`` is
    784 = 28 * 28.
    """

    def __init__(self, latent_dim=100, hidden_dim=128, output_dim=784):
        super().__init__()
        # Keep the configured sizes available for inspection by callers.
        self.latent_dim, self.hidden_dim, self.output_dim = latent_dim, hidden_dim, output_dim
        # Layer creation order is deliberate: it fixes the parameter
        # registration (state_dict) order and the RNG draws made during init.
        self.linear1 = nn.Linear(in_features=latent_dim, out_features=hidden_dim)
        self.linear2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.linear3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.activation = nn.ReLU()
        self.output_activation = nn.Tanh()

    def forward(self, x):
        """Map the noise batch ``x`` to generated samples in (-1, 1)."""
        for hidden_layer in (self.linear1, self.linear2):
            x = self.activation(hidden_layer(x))
        return self.output_activation(self.linear3(x))
# Discriminator model: scores each sample with a probability in (0, 1).
class Discriminator(nn.Module):
    """Three-layer MLP discriminator.

    An ``input_dim``-dimensional sample passes through two ReLU hidden layers
    of width ``hidden_dim``. The result is projected to ``output_dim`` values,
    which a Sigmoid squashes into (0, 1). These values can be used directly as
    probabilities with ``nn.BCELoss``.

    Inputs with more than two dimensions are first flattened to ``(batch, -1)``.
    Image batches shaped ``(batch, 1, 28, 28)`` therefore work directly, as long
    as the trailing dimensions multiply to ``input_dim``.
    """

    def __init__(self, input_dim=784, hidden_dim=128, output_dim=1):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.ReLU()
        self.output_activation = nn.Sigmoid()

    def forward(self, x):
        """Return Sigmoid scores of shape ``(batch, output_dim)`` for ``x``."""
        # nn.Linear only transforms the last dimension. Image batches such as
        # (B, 1, 28, 28) must be flattened after the batch axis, otherwise
        # linear1 fails with a shape mismatch. 1-D and 2-D inputs are left
        # untouched, exactly as before.
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        x = self.activation(self.linear1(x))
        x = self.activation(self.linear2(x))
        return self.output_activation(self.linear3(x))
# Training loop for the GAN.
def train_gan(generator, discriminator, dataloader, device, epochs=10, lr=0.0002, latent_dim=100):
    """Train ``generator`` and ``discriminator`` adversarially with BCE loss.

    On each batch, the discriminator is first updated to label real samples 1
    and generated samples 0. The generator is then updated so that the
    discriminator labels its samples 1. Both losses are printed every 100
    batches.

    Args:
        generator: module mapping a ``(batch, latent_dim)`` noise tensor to
            generated samples that the discriminator can score.
        discriminator: module mapping samples to ``(batch, 1)`` probabilities
            (it must end in a Sigmoid, since ``nn.BCELoss`` is used).
        dataloader: sized iterable of ``(data, label)`` batches; labels are
            ignored. Each ``data`` batch is flattened to ``(batch, -1)``, so
            image-shaped batches such as MNIST's ``(B, 1, 28, 28)`` are accepted.
        device: device the models live on; all batch tensors are placed there.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate used for both optimizers.
        latent_dim: size of the noise vector fed to ``generator``.
    """
    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=lr)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr)
    num_batches = len(dataloader)
    for epoch in range(epochs):
        for i, (real_data, _) in enumerate(dataloader):
            batch_size = real_data.size(0)
            # Flatten image batches (B, C, H, W) into the (B, C*H*W) vectors the
            # MLP discriminator expects; reshape also copes with non-contiguous input.
            real_data = real_data.to(device).reshape(batch_size, -1)
            # Allocate directly on the target device instead of creating on CPU and copying.
            labels_real = torch.ones(batch_size, 1, device=device)
            labels_fake = torch.zeros(batch_size, 1, device=device)
            latent = torch.randn(batch_size, latent_dim, device=device)
            fake_data = generator(latent)

            # Discriminator step. fake_data is detached so this loss does not
            # backpropagate into the generator.
            optimizer_d.zero_grad()
            output_real = discriminator(real_data)
            loss_real = criterion(output_real, labels_real)
            output_fake = discriminator(fake_data.detach())
            loss_fake = criterion(output_fake, labels_fake)
            loss_d = loss_real + loss_fake
            loss_d.backward()
            optimizer_d.step()

            # Generator step: score the fakes with the freshly updated
            # discriminator, using "real" targets. The gradients this leaves on
            # the discriminator's parameters are cleared by optimizer_d.zero_grad()
            # before the next discriminator step.
            optimizer_g.zero_grad()
            output_fake = discriminator(fake_data)
            loss_g = criterion(output_fake, labels_real)
            loss_g.backward()
            optimizer_g.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Batch [{i+1}/{num_batches}], Loss D: {loss_d.item()}, Loss G: {loss_g.item()}")
# Script entry point: load MNIST, build the models and train the GAN.
# The guard keeps importing this module from downloading data or starting training.
if __name__ == "__main__":
    # ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
    # the same range as the generator's Tanh output.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    trainset = datasets.MNIST('data', train=True, download=True, transform=transform)
    dataloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

    # Create the models on the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # A single latent size shared by the generator and the training loop,
    # so the two can never disagree.
    LATENT_DIM = 100
    generator = Generator(latent_dim=LATENT_DIM).to(device)
    discriminator = Discriminator().to(device)

    # Train the GAN.
    train_gan(generator, discriminator, dataloader, device, epochs=10, lr=0.0002, latent_dim=LATENT_DIM)
```
这是一个基于全连接层（MLP）的简单 GAN 模型，训练后可以生成 MNIST 手写数字图像。注意：生成器输出的是长度为 784、取值在 [-1, 1] 范围内的一维向量，需要先 reshape 为 28×28 才能作为图像显示。
阅读全文