重新写一份pytorchGAN模型代码
时间: 2023-12-15 21:25:19 浏览: 117
基于pytorch的GAN.py代码
以下是一个简单的PyTorch GAN模型代码,用于生成手写数字图像:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
# Generator: maps a latent noise vector to a 64x64 single-channel image (DCGAN).
class Generator(nn.Module):
    """DCGAN generator.

    Upsamples a latent vector of shape (N, nz, 1, 1) through a stack of
    transposed convolutions to a (N, 1, 64, 64) image in [-1, 1] (Tanh).

    Args:
        ngf: base number of generator feature maps.
        nz: dimensionality of the input noise vector (default 100,
            matching the previously hard-coded value — backward compatible).
    """

    def __init__(self, ngf=64, nz=100):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # (N, nz, 1, 1) -> (N, ngf*8, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (N, ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (N, ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (N, 1, 64, 64). BUGFIX: the original stopped at 32x32,
            # but the discriminator and the MNIST pipeline (Resize(64))
            # both expect 64x64 — netD(fake) would fail on the final
            # 4x4 convolution over a 2x2 feature map.
            nn.ConvTranspose2d(ngf, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Generate images from latent noise of shape (N, nz, 1, 1)."""
        return self.main(input)
# Discriminator: classifies 64x64 single-channel images as real (1) or fake (0).
class Discriminator(nn.Module):
    """DCGAN discriminator.

    Downsamples a (N, 1, 64, 64) image through strided convolutions to a
    single per-sample probability in [0, 1] (Sigmoid).

    Args:
        ndf: base number of discriminator feature maps.
    """

    def __init__(self, ndf=64):
        super(Discriminator, self).__init__()
        # First block has no batch norm, per the DCGAN convention.
        layers = [
            nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three strided blocks, doubling channels while halving spatial size:
        # ndf -> 2*ndf -> 4*ndf -> 8*ndf.
        channels = ndf
        for _ in range(3):
            layers += [
                nn.Conv2d(channels, channels * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(channels * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels *= 2
        # Final 4x4 convolution collapses the map to one logit per sample.
        layers += [
            nn.Conv2d(channels, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a 1-D tensor of real/fake probabilities, one per image."""
        scores = self.main(input)
        return scores.view(-1, 1).squeeze(1)
# Training loop
def train():
    """Train a DCGAN on MNIST (resized to 64x64) and save both networks.

    Side effects: downloads MNIST into ./data on first run, prints the
    discriminator/generator losses every iteration, and writes
    netG.pth / netD.pth into the working directory.
    """
    # Hyperparameters.
    batch_size = 64
    learning_rate = 0.0002
    num_epochs = 100
    beta1 = 0.5  # Adam beta1 recommended for GAN stability (DCGAN paper).
    # Run on GPU when available; the original was CPU-only.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # MNIST resized to 64x64 and normalized to [-1, 1] to match Tanh output.
    dataset = dset.MNIST(root='./data', train=True, transform=transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ]), download=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Networks, loss, optimizers.
    netG = Generator().to(device)
    netD = Discriminator().to(device)
    criterion = nn.BCELoss()
    optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta1, 0.999))

    for epoch in range(num_epochs):
        for i, (real, _) in enumerate(dataloader):
            # NOTE: the deprecated Variable wrapper was removed — tensors
            # track gradients directly since PyTorch 0.4.
            real = real.to(device)
            b_size = real.size(0)

            # --- Discriminator step: maximize log D(x) + log(1 - D(G(z))) ---
            netD.zero_grad()
            output = netD(real)
            errD_real = criterion(output, torch.ones(b_size, device=device))
            errD_real.backward()

            noise = torch.randn(b_size, 100, 1, 1, device=device)
            fake = netG(noise)
            # detach() keeps generator gradients out of the discriminator step.
            output = netD(fake.detach())
            errD_fake = criterion(output, torch.zeros(b_size, device=device))
            errD_fake.backward()
            errD = errD_real + errD_fake
            optimizerD.step()

            # --- Generator step: maximize log D(G(z)) ---
            netG.zero_grad()
            output = netD(fake)
            errG = criterion(output, torch.ones(b_size, device=device))
            errG.backward()
            optimizerG.step()

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch, num_epochs, i, len(dataloader),
                errD.item(), errG.item()))

    # Persist the trained weights.
    torch.save(netG.state_dict(), 'netG.pth')
    torch.save(netD.state_dict(), 'netD.pth')
# Script entry point: start training only when run directly, not on import.
if __name__ == '__main__':
    train()
```
这个模型使用了一个简单的生成器和判别器,用于生成手写数字图像。在训练过程中,先训练判别器,再训练生成器,以此来提高生成器的生成质量。在训练过程中,会输出每一次迭代的损失值,以便进行模型调整和优化。
阅读全文