GAN算法python代码实现
时间: 2025-01-03 10:33:45 浏览: 12
### 如何使用Python实现GAN(生成对抗网络)
#### 构建生成器
生成器的目标是从随机噪声中生成逼真的图像。通常,生成器是一个反卷积神经网络。
```python
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, input_dim=100, output_dim=3, hidden_dim=64):
super(Generator, self).__init__()
model = [
*block(input_dim, hidden_dim * 8, normalize=False), # (N, 512, 4, 4)
*block(hidden_dim * 8, hidden_dim * 4), # (N, 256, 8, 8)
*block(hidden_dim * 4, hidden_dim * 2), # (N, 128, 16, 16)
*block(hidden_dim * 2, hidden_dim), # (N, 64, 32, 32)
nn.ConvTranspose2d(hidden_dim, output_dim, kernel_size=4, stride=2, padding=1),
nn.Tanh() # 输出范围[-1, 1]
]
model.append(nn.Tanh())
model.insert(0, nn.ConvTranspose2d(input_dim, hidden_dim*8, 4, 1, 0))
model.insert(1, nn.BatchNorm2d(hidden_dim*8))
self.model = nn.Sequential(*model)
def forward(self, z):
img = self.model(z.view(-1, z.size(1), 1, 1))
return img
def block(in_channels, out_channels, normalize=True):
layers = [nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1)]
if normalize:
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.ReLU(inplace=True))
return layers
```
#### 构建判别器
判别器的作用是区分真实图片和由生成器产生的假图。一般情况下,判别器会设计成一个普通的CNN架构。
```python
class Discriminator(nn.Module):
def __init__(self, input_shape=(3, 128, 128)):
super(Discriminator, self).__init__()
channels, height, width = input_shape
def discriminator_block(in_filters, out_filters, bn=True):
block_elements = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout2d(0.25)]
if bn:
block_elements.append(nn.BatchNorm2d(out_filters, 0.8))
return block_elements
self.model = nn.Sequential(
*discriminator_block(channels, 16, False),
*discriminator_block(16, 32),
*discriminator_block(32, 64),
*discriminator_block(64, 128),
)
ds_size = height // 2 ** 4
self.adv_layer = nn.Linear(128 * ds_size ** 2, 1)
def forward(self, img):
features = self.model(img)
validity = self.adv_layer(features.view(features.size(0), -1))
return validity
```
#### 训练循环
定义损失函数并设置优化器来交替更新生成器和判别器的参数。
```python
import torch.optim as optim
adversarial_loss = nn.BCEWithLogitsLoss()
generator = Generator()
discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
for epoch in range(num_epochs):
for i, (imgs, _) in enumerate(dataloader):
valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)
real_imgs = Variable(imgs.type(Tensor))
optimizer_D.zero_grad()
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
gen_imgs = generator(z)
loss_real = adversarial_loss(discriminator(real_imgs), valid)
loss_fake = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (loss_real + loss_fake) / 2
d_loss.backward()
optimizer_D.step()
optimizer_G.zero_grad()
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
```
上述代码片段展示了如何创建基本的生成对抗网络框架[^1]。对于更复杂的应用场景,可以根据具体需求调整模型结构、增加额外组件或改进训练技巧以提高性能[^3]。
阅读全文