编写GAN代码生成图像
时间: 2024-04-02 07:11:31 浏览: 10
好的,下面是一个简单的GAN模型代码,用于生成手写数字图像。
首先,我们定义生成器和判别器的网络结构:
```python
import torch
import torch.nn as nn
# 定义生成器
class Generator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Generator, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, output_size)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
out = self.relu(out)
out = self.fc3(out)
out = self.sigmoid(out)
return out
# 定义判别器
class Discriminator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Discriminator, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, output_size)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
out = self.relu(out)
out = self.fc3(out)
out = self.sigmoid(out)
return out
```
然后,我们定义一些超参数和优化器:
```python
# 定义超参数
input_size = 100
hidden_size = 128
output_size = 784
batch_size = 100
num_epochs = 100
# 定义优化器
G = Generator(input_size, hidden_size, output_size)
D = Discriminator(output_size, hidden_size, 1)
criterion = nn.BCELoss()
G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
D_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
```
接下来,我们定义训练过程:
```python
# 定义训练过程
for epoch in range(num_epochs):
for i, (images, _) in enumerate(train_loader):
# 定义真实样本和噪声样本
real_images = images.view(batch_size, -1)
noise = torch.randn(batch_size, input_size)
# 训练判别器
D_real = D(real_images)
D_fake = D(G(noise))
D_loss_real = criterion(D_real, torch.ones(batch_size, 1))
D_loss_fake = criterion(D_fake, torch.zeros(batch_size, 1))
D_loss = D_loss_real + D_loss_fake
D_optimizer.zero_grad()
D_loss.backward()
D_optimizer.step()
# 训练生成器
noise = torch.randn(batch_size, input_size)
G_fake = G(noise)
D_fake = D(G_fake)
G_loss = criterion(D_fake, torch.ones(batch_size, 1))
G_optimizer.zero_grad()
G_loss.backward()
G_optimizer.step()
if (i+1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], D_loss: {:.4f}, G_loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, len(train_loader), D_loss.item(), G_loss.item()))
```
最后,我们可以使用生成器生成一些手写数字图像:
```python
# 生成手写数字图像
num_test_samples = 16
noise = torch.randn(num_test_samples, input_size)
test_images = G(noise)
test_images = test_images.view(num_test_samples, 1, 28, 28).data
for i in range(num_test_samples):
plt.imshow(test_images[i].numpy().squeeze(), cmap='gray')
plt.show()
```
这样,我们就完成了一个简单的GAN模型,用于生成手写数字图像。