使用DCGAN训练自己的数据集,pytorch
时间: 2023-09-12 16:11:17 浏览: 190
Pytorch使用MNIST数据集实现基础GAN和DCGAN详解
以下是使用Pytorch训练自己的数据集的DCGAN示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.utils import save_image
# 设置随机种子
torch.manual_seed(42)
# 定义生成器网络
class Generator(nn.Module):
def __init__(self, latent_dim):
super(Generator, self).__init__()
self.latent_dim = latent_dim
self.model = nn.Sequential(
nn.ConvTranspose2d(latent_dim, 256, kernel_size=4, stride=1, padding=0, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh()
)
def forward(self, z):
z = z.view(-1, self.latent_dim, 1, 1)
img = self.model(z)
return img
# 定义判别器网络
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=0, bias=False),
nn.Sigmoid()
)
def forward(self, img):
validity = self.model(img)
return validity.view(-1, 1)
# 定义训练函数
def train(generator, discriminator, dataloader, epochs, device, latent_dim, lr):
# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
for epoch in range(epochs):
for i, (imgs, _) in enumerate(dataloader):
real_imgs = imgs.to(device)
# 训练判别器
optimizer_D.zero_grad()
# 计算真实图像的损失
real_labels = torch.ones((real_imgs.size(0), 1)).to(device)
real_loss = criterion(discriminator(real_imgs), real_labels)
# 计算生成器生成的图像的损失
z = torch.randn((real_imgs.size(0), latent_dim)).to(device)
fake_imgs = generator(z)
fake_labels = torch.zeros((real_imgs.size(0), 1)).to(device)
fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
# 计算总体损失并进行反向传播和优化
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# 训练生成器
optimizer_G.zero_grad()
# 计算生成器生成的图像的损失
z = torch.randn((real_imgs.size(0), latent_dim)).to(device)
fake_imgs = generator(z)
real_labels = torch.ones((real_imgs.size(0), 1)).to(device)
g_loss = criterion(discriminator(fake_imgs), real_labels)
# 进行反向传播和优化
g_loss.backward()
optimizer_G.step()
# 输出训练信息
if i % 100 == 0:
print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch+1, epochs, i, len(dataloader), d_loss.item(), g_loss.item()))
# 保存生成的图像
save_image(fake_imgs[:25], "images/%d.png" % epoch, nrow=5, normalize=True)
# 设置超参数
latent_dim = 100
batch_size = 64
epochs = 50
lr = 0.0002
image_size = 64
image_channels = 3
# 设置数据预处理
transforms = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载数据集
dataset = ImageFolder(root="data", transform=transforms)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 创建生成器和判别器
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
# 训练模型
train(generator, discriminator, dataloader, epochs, device, latent_dim, lr)
```
这里的`data`目录下应该包含你自己的图像数据集。你可以自己定义数据预处理方式,然后使用`ImageFolder`和`DataLoader`来加载数据集。在训练过程中,我们会将生成的图像保存在`images`目录下,你可以在训练过程中观察它们的变化。最后,训练结束后,你可以从生成器中生成一些新的图像,然后保存它们,以便进一步的可视化和分析。
阅读全文