基于 PyTorch 的 DCGAN 代码:以 11×11 灰度图文件夹作为输入,生成更多 11×11 灰度图以扩充数据集
时间: 2023-11-30 18:03:04 浏览: 127
好的,以下是基于 PyTorch 的完整 DCGAN 示例代码(包括生成器、判别器、数据集类和训练循环),以 11×11 灰度图作为训练数据,生成更多 11×11 灰度图来扩充数据集:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
# 定义生成器(Generator)和判别器(Discriminator)
class Generator(nn.Module):
    """DCGAN generator that maps a latent noise vector to an 11x11 image.

    The previous layer stack (kernel 4 / stride 2 upsampling from 1x1)
    produced 32x32 images, which do not match 11x11 training data. This
    stack upsamples 1x1 -> 3x3 -> 6x6 -> 11x11 instead.

    Args:
        nz: Length of the latent vector. Inputs must have shape
            (batch, nz, 1, 1).
        ngf: Base number of generator feature maps.
        nc: Number of output image channels (1 for grayscale).

    forward() returns a tensor of shape (batch, nc, 11, 11) with values in
    [-1, 1] (from the final Tanh).
    """

    def __init__(self, nz=100, ngf=64, nc=1):
        super(Generator, self).__init__()
        # Exposed so training code can sample noise of the matching size.
        self.nz = nz
        self.main = nn.Sequential(
            # (nz, 1, 1) -> (ngf*4, 3, 3)
            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # (ngf*4, 3, 3) -> (ngf*2, 6, 6)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # (ngf*2, 6, 6) -> (ngf, 11, 11)
            nn.ConvTranspose2d(ngf * 2, ngf, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # (ngf, 11, 11) -> (nc, 11, 11); no BatchNorm on the output layer
            nn.ConvTranspose2d(ngf, nc, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Generate images from noise of shape (batch, nz, 1, 1)."""
        return self.main(input)
class Discriminator(nn.Module):
    """DCGAN discriminator that scores 11x11 images as real (1) or fake (0).

    The previous stack (four stride-2 convolutions followed by a 4x4 valid
    convolution) needs inputs of at least 64x64 and raised an error on 11x11
    images. This stack downsamples 11x11 -> 6x6 -> 3x3 -> 1x1 instead.

    Args:
        ndf: Base number of discriminator feature maps.
        nc: Number of input image channels (1 for grayscale).

    forward() takes (batch, nc, 11, 11) and returns probabilities of shape
    (batch, 1).
    """

    def __init__(self, ndf=64, nc=1):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # (nc, 11, 11) -> (ndf, 11, 11); no BatchNorm on the input layer
            nn.Conv2d(nc, ndf, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf, 11, 11) -> (ndf*2, 6, 6)
            nn.Conv2d(ndf, ndf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*2, 6, 6) -> (ndf*4, 3, 3)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*4, 3, 3) -> (1, 1, 1); Sigmoid gives a probability for BCELoss
            nn.Conv2d(ndf * 4, 1, 3, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return one real/fake probability per image, shape (batch, 1)."""
        return self.main(input).view(-1, 1)
# 定义训练函数
def train(dataloader, generator, discriminator, optimizer_G, optimizer_D, device, epochs, save_interval,
          output_dir='output'):
    """Train a GAN with the standard alternating binary cross-entropy objective.

    Each iteration first updates the discriminator on a real batch (target 1)
    and a detached fake batch (target 0). It then updates the generator so
    that the discriminator labels its fakes as real.

    Args:
        dataloader: Iterable of real-image batches. Each item is either an
            image tensor or a list/tuple whose first element is the images
            (e.g. (images, labels)).
        generator: Maps noise of shape (batch, nz, 1, 1) to images. nz is
            read from ``generator.nz`` when that attribute exists, else 100.
        discriminator: Returns one probability per image.
        optimizer_G: Optimizer over the generator's parameters.
        optimizer_D: Optimizer over the discriminator's parameters.
        device: Device that real images and noise are moved to.
        epochs: Number of passes over ``dataloader``.
        save_interval: Save a 5-column grid of up to 25 generated samples
            every ``save_interval`` batches, counted across epochs.
        output_dir: Directory for the sample grids; created if missing.
    """
    criterion = nn.BCELoss()
    # Noise length must match the generator's first layer.
    nz = getattr(generator, 'nz', 100)
    os.makedirs(output_dir, exist_ok=True)
    for epoch in range(epochs):
        for i, data in enumerate(dataloader, 0):
            # A dataset returning bare tensors collates to a tensor batch;
            # indexing that with [0] would select a single image instead.
            real_imgs = data[0] if isinstance(data, (list, tuple)) else data
            real_imgs = real_imgs.to(device)
            batch_size = real_imgs.size(0)
            real_label = torch.ones(batch_size, dtype=torch.float, device=device)
            fake_label = torch.zeros(batch_size, dtype=torch.float, device=device)

            # --- Discriminator step ---
            optimizer_D.zero_grad()
            output = discriminator(real_imgs).view(-1)
            errD_real = criterion(output, real_label)
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake_imgs = generator(noise)
            # detach(): this loss must not backpropagate into the generator.
            output = discriminator(fake_imgs.detach()).view(-1)
            errD_fake = criterion(output, fake_label)
            errD = errD_real + errD_fake
            errD.backward()
            optimizer_D.step()

            # --- Generator step ---
            optimizer_G.zero_grad()
            # Re-score the fakes with the updated discriminator and target
            # "real". Only optimizer_G steps here; the gradients this leaves
            # on the discriminator are cleared by its next zero_grad().
            output = discriminator(fake_imgs).view(-1)
            errG = criterion(output, real_label)
            errG.backward()
            optimizer_G.step()

            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                      % (epoch, epochs, i, len(dataloader), errD.item(), errG.item()))
            batches_done = epoch * len(dataloader) + i
            if batches_done % save_interval == 0:
                save_image(fake_imgs.detach()[:25],
                           os.path.join(output_dir, '%d.png' % batches_done),
                           nrow=5, normalize=True)
# 定义数据类(Data Class)
class ImageDataset(Dataset):
    """Flat (non-recursive) folder of images, loaded in grayscale ('L') mode.

    Args:
        root: Directory containing the image files.
        transform: Optional callable applied to each PIL image.
        extensions: File suffixes to accept, compared case-insensitively.
            Defaults to ``IMG_EXTENSIONS``. Other files and all
            sub-directories are skipped.

    ``__getitem__`` returns only the (optionally transformed) image, with no
    label.
    """

    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.pgm', '.ppm', '.webp')

    def __init__(self, root, transform=None, extensions=None):
        self.transform = transform
        if extensions is None:
            extensions = self.IMG_EXTENSIONS
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted so that index -> file is deterministic (os.listdir order is arbitrary).
        self.image_paths = [
            os.path.join(root, name)
            for name in sorted(os.listdir(root))
            if name.lower().endswith(suffixes) and os.path.isfile(os.path.join(root, name))
        ]

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        # convert() returns a new image, so the source file can be closed right away.
        with Image.open(image_path) as img:
            image = img.convert('L')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.image_paths)
# Hyperparameters. nz/ngf/ndf/nc are passed to the model constructors below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
lr = 0.0002
betas = (0.5, 0.999)
epochs = 50
nz = 100
ngf = 64
ndf = 64
nc = 1
save_interval = 500

# Preprocessing: force 11x11, convert to a [0, 1] tensor, then rescale to
# [-1, 1] so real images share the range of the generator's Tanh output.
transform = transforms.Compose([
    # A (height, width) tuple resizes to exactly 11x11. Resize(11) would only
    # match the shorter edge and keep the aspect ratio of non-square inputs.
    transforms.Resize((11, 11)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# The guard is needed because DataLoader(num_workers=2) starts worker
# processes. With the "spawn" start method (the default on Windows and macOS),
# each worker re-imports this module and would otherwise start training again.
if __name__ == "__main__":
    # Replace 'data_folder_path' with the directory that holds the 11x11 images.
    dataset = ImageDataset(root='data_folder_path', transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    # Build both networks and move them to the selected device.
    generator = Generator(nz=nz, ngf=ngf, nc=nc).to(device)
    discriminator = Discriminator(ndf=ndf, nc=nc).to(device)
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=betas)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)
    # train() writes sample grids as output/<batch>.png; ensure the folder exists.
    os.makedirs('output', exist_ok=True)
    train(dataloader, generator, discriminator, optimizer_G, optimizer_D, device, epochs, save_interval)
```
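其中,需要将 'data_folder_path' 替换为存放 11×11 灰度图的文件夹路径。生成样本是在训练过程中保存的,而不是在训练结束后:每隔 save_interval 个批次,当前生成器输出的最多 25 张样本会以 5 列网格拼成一张 PNG,保存到 output 文件夹下。如果要把生成的单张 11×11 图像用于扩充数据集,可以在训练结束后用生成器对随机噪声采样,再逐张保存。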
阅读全文