cgan代码
时间: 2023-07-02 09:13:17 浏览: 136
以下是一个简单的CGAN的代码示例:
```python
import torch
import torch.nn as nn
import numpy as np
# 定义 generator 和 discriminator 的网络结构
class Generator(nn.Module):
    """Conditional generator: maps (noise, condition) pairs to samples.

    A 3-layer fully connected network; the condition vector is fused with
    the noise by concatenation along the feature axis, and the output is
    squashed into [-1, 1] by tanh.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        super(Generator, self).__init__()
        # Layer creation order (fc1, fc2, fc3) matters for seeded reproducibility.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, y):
        # Concatenate noise `x` and condition `y` into one feature vector.
        hidden = torch.cat([x, y], dim=1)
        hidden = nn.functional.relu(self.fc1(hidden))
        hidden = nn.functional.relu(self.fc2(hidden))
        return torch.tanh(self.fc3(hidden))
class Discriminator(nn.Module):
    """Conditional discriminator: scores (sample, condition) pairs.

    A 3-layer fully connected network whose sigmoid output is the
    probability that the sample is real given the condition.
    """

    def __init__(self, input_dim, hidden_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x, y):
        # Condition the decision by concatenating sample and label features.
        joint = torch.cat([x, y], dim=1)
        joint = nn.functional.relu(self.fc1(joint))
        joint = nn.functional.relu(self.fc2(joint))
        return torch.sigmoid(self.fc3(joint))
# ---- Loss and training loop -------------------------------------------------
criterion = nn.BCELoss()  # binary cross-entropy on D's real/fake probabilities


def train_GAN(num_epochs, data_loader):
    """Alternately update discriminator and generator over `data_loader`.

    Each batch yields (real_data, class_label); the class label is the CGAN
    condition fed to BOTH networks, while separate ones/zeros tensors are the
    real/fake targets for the BCE loss. Relies on module-level `generator`,
    `discriminator`, the two optimizers, `input_dim` and `z_dim`.
    """
    for epoch in range(num_epochs):
        for i, (real_data, class_label) in enumerate(data_loader):
            bsz = real_data.shape[0]
            # The dataset label is the *condition*, not the real/fake target.
            class_label = class_label.float().view(-1, 1)
            real_data = real_data.view(-1, input_dim)
            real_target = torch.ones(bsz, 1)   # D should output 1 for real
            fake_target = torch.zeros(bsz, 1)  # D should output 0 for fake

            # --- Discriminator step ---
            D_optimizer.zero_grad()
            real_decision = discriminator(real_data, class_label)
            D_real_loss = criterion(real_decision, real_target)
            # detach(): the discriminator step must not backprop into G.
            fake_data = generator(torch.randn(bsz, z_dim), class_label).detach()
            fake_decision = discriminator(fake_data, class_label)
            D_fake_loss = criterion(fake_decision, fake_target)
            D_loss = D_real_loss + D_fake_loss
            D_loss.backward()
            D_optimizer.step()

            # --- Generator step ---
            G_optimizer.zero_grad()
            fake_data = generator(torch.randn(bsz, z_dim), class_label)
            fake_decision = discriminator(fake_data, class_label)
            # G wants D to classify its samples as real (target = 1).
            G_loss = criterion(fake_decision, real_target)
            G_loss.backward()
            G_optimizer.step()

            if (i + 1) % 20 == 0:
                print("Epoch [{}/{}], Step [{}/{}], Discriminator Loss: {:.4f}, Generator Loss: {:.4f}"
                      .format(epoch + 1, num_epochs, i + 1, len(data_loader), D_loss.item(), G_loss.item()))


# ---- Hyper-parameters -------------------------------------------------------
input_dim = 2     # dimensionality of a real data sample
output_dim = 2    # dimensionality of a generated sample
hidden_dim = 128  # hidden width of both MLPs
z_dim = 10        # noise vector length
batch_size = 64   # samples per training batch

# ---- Dataset ----------------------------------------------------------------
# 1000 samples from a 2-D standard normal; all carry class label 0.
data = np.random.multivariate_normal([0, 0], [[1, 0], [0, 1]], size=1000)
label = np.zeros((1000, 1))
# float32 is required to match nn.Linear's default parameter dtype
# (float64 inputs would raise a dtype mismatch at the first forward pass).
dataset = torch.utils.data.TensorDataset(
    torch.tensor(data, dtype=torch.float32),
    torch.tensor(label, dtype=torch.float32),
)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ---- Models and optimizers --------------------------------------------------
# Models MUST be created before their optimizers (the original defined the
# optimizers first, which raised NameError).
# G's input is z_dim noise values + 1 condition value (the scalar label);
# D's input is input_dim data values + 1 condition value.
generator = Generator(z_dim + 1, output_dim, hidden_dim)
discriminator = Discriminator(input_dim + 1, hidden_dim)
G_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
D_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# ---- Train ------------------------------------------------------------------
num_epochs = 50
if __name__ == "__main__":
    # Guarded so importing this module does not kick off training.
    train_GAN(num_epochs, data_loader)
```
这个CGAN的代码中,Generator和Discriminator的网络结构都比较简单,只有3层全连接层。在训练过程中,我们先训练Discriminator,然后再训练Generator,交替进行,期望通过这个过程让Generator生成的假数据越来越逼近真实数据的分布。
阅读全文