基于 PyTorch,输入为 11×11(单通道灰度图像)的图片数据集文件夹,用 GAN 生成扩充数据集,并保存模型和生成的图片,之后加载模型生成不同的 11×11 大小的图像,代码
时间: 2024-03-16 21:42:18 浏览: 127
python扩充图片数据集
好的,你的问题是如何用基于 PyTorch 的 GAN 生成 11×11 大小的灰度图片数据集,并保存模型和生成的图片,以及如何加载模型生成不同的 11×11 大小的图像。下面是代码实现:
```python
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
# Generator and discriminator networks
class Generator(nn.Module):
    """Fully connected generator that maps a latent vector to an image.

    Args:
        input_size: Length of the latent (noise) vector fed to ``forward``.
        num_channels: Number of channels of the generated image.
        output_size: Height and width of the (square) generated image.
    """

    def __init__(self, input_size, num_channels, output_size):
        super().__init__()
        self.input_size = input_size
        self.num_channels = num_channels
        self.output_size = output_size
        # Attribute names are part of the checkpoint format (state_dict keys
        # are derived from them), so they must not be renamed.
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, output_size * output_size * num_channels)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(512)
        self.bn3 = nn.BatchNorm1d(1024)

    def forward(self, x):
        """Return images of shape (batch, num_channels, output_size, output_size).

        ``x`` is a 2-D tensor of latent vectors with ``input_size`` columns.
        The final tanh bounds every output value to [-1, 1].
        """
        # Three identical Linear -> BatchNorm -> ReLU stages.
        for fc, bn in ((self.fc1, self.bn1), (self.fc2, self.bn2), (self.fc3, self.bn3)):
            x = self.relu(bn(fc(x)))
        x = self.tanh(self.fc4(x))
        # Unflatten the last layer's output into image layout.
        return x.view(-1, self.num_channels, self.output_size, self.output_size)
class Discriminator(nn.Module):
    """Convolutional discriminator that scores images as real (1) or fake (0).

    Three stride-2 convolutions each roughly halve the spatial size, then a
    4x4 convolution produces a single-channel logit map. Because of that 4x4
    kernel the input must be at least 32x32.

    Args:
        input_size: Stored for reference only; the layers do not depend on it.
        num_channels: Number of channels of the input images.
        output_size: Stored for reference only; the layers do not depend on it.
    """

    def __init__(self, input_size, num_channels, output_size):
        super().__init__()
        self.input_size = input_size
        self.num_channels = num_channels
        self.output_size = output_size
        self.conv1 = nn.Conv2d(num_channels, 64, 4, 2, 1)
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)
        self.conv4 = nn.Conv2d(256, 1, 4, 1, 0)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return one probability per image, shape (batch, 1)."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.conv4(x)
        # conv4 only collapses to 1x1 for 32x32 inputs; a 64x64 input leaves a
        # 5x5 logit map. Average that map per image so the output always has
        # exactly one row per input image (flattening it instead would yield
        # 25 rows per image and break any per-image loss).
        x = x.mean(dim=(2, 3))
        x = self.sigmoid(x)
        return x.view(-1, 1)
# Hyperparameters
batch_size = 128
num_epochs = 100
learning_rate = 0.0002
betas = (0.5, 0.999)
# Dataset and data loader
transform = transforms.Compose([
    # ImageFolder's default loader opens every file as RGB. Collapse to a
    # single channel so the tensors match the num_channels=1 networks below.
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    # Map pixel values from [0, 1] to [-1, 1], the range of the generator's tanh.
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
# NOTE: ImageFolder expects 'data' to contain one sub-directory per class;
# the class labels themselves are ignored during training.
train_dataset = datasets.ImageFolder('data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Generator and discriminator instances
generator = Generator(input_size=100, num_channels=1, output_size=64)
discriminator = Discriminator(input_size=64, num_channels=1, output_size=1)
# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate, betas=betas)
optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=betas)
# Training loop
def _save_checkpoint(generator, epoch, latent_dim, device):
    """Save the generator weights and an 8x8 grid of samples for ``epoch``."""
    os.makedirs('models', exist_ok=True)
    os.makedirs('results', exist_ok=True)
    torch.save(generator.state_dict(), 'models/generator-{}.ckpt'.format(epoch))
    # Sample in eval mode so BatchNorm uses its running statistics, and
    # without autograd because the images are only written to disk.
    generator.eval()
    with torch.no_grad():
        samples = generator(torch.randn(64, latent_dim, device=device))
    torchvision.utils.save_image(samples, 'results/fake_images-{}.png'.format(epoch), nrow=8, normalize=True)
    generator.train()


def train(num_epochs, generator, discriminator, optimizer_g, optimizer_d, train_loader, criterion=None):
    """Alternate discriminator and generator updates over ``train_loader``.

    Every 10th epoch the generator weights are written to
    ``models/generator-<epoch>.ckpt`` and a grid of 64 samples to
    ``results/fake_images-<epoch>.png``.

    Args:
        num_epochs: Number of passes over ``train_loader``.
        generator: Module mapping (batch, latent_dim) noise to images. The
            latent size is read from ``generator.input_size`` (100 if absent).
        discriminator: Module mapping images to probabilities in (0, 1).
        optimizer_g: Optimizer over the generator's parameters.
        optimizer_d: Optimizer over the discriminator's parameters.
        train_loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        criterion: Loss taking ``(probabilities, targets)``. Defaults to
            ``nn.BCELoss()``.
    """
    if criterion is None:
        criterion = nn.BCELoss()
    latent_dim = getattr(generator, 'input_size', 100)
    device = next(discriminator.parameters()).device

    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(train_loader):
            batch = images.size(0)
            if batch < 2:
                # A BatchNorm generator cannot compute batch statistics from a
                # single sample in training mode, so skip a one-image batch.
                continue
            images = images.to(device)

            # --- Discriminator step ---
            discriminator.zero_grad()
            outputs = discriminator(images)
            # Build the targets from the output itself so their shape always
            # matches what the discriminator returns (BCELoss requires that).
            real_loss = criterion(outputs, torch.ones_like(outputs))
            real_score = outputs
            z = torch.randn(batch, latent_dim, device=device)
            # detach(): this step must not backpropagate into the generator.
            outputs = discriminator(generator(z).detach())
            fake_loss = criterion(outputs, torch.zeros_like(outputs))
            fake_score = outputs
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_d.step()

            # --- Generator step ---
            generator.zero_grad()
            z = torch.randn(batch, latent_dim, device=device)
            outputs = discriminator(generator(z))
            # The generator is rewarded when its fakes are scored as real.
            g_loss = criterion(outputs, torch.ones_like(outputs))
            g_loss.backward()
            optimizer_g.step()

            # Progress report every 100 steps
            if (i + 1) % 100 == 0:
                print("Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}"
                      .format(epoch + 1, num_epochs, i + 1, len(train_loader), d_loss.item(), g_loss.item(),
                              real_score.mean().item(), fake_score.mean().item()))

        # Checkpoint and sample grid every 10 epochs
        if (epoch + 1) % 10 == 0:
            _save_checkpoint(generator, epoch + 1, latent_dim, device)
# Run training
train(
    num_epochs=num_epochs,
    generator=generator,
    discriminator=discriminator,
    optimizer_g=optimizer_g,
    optimizer_d=optimizer_d,
    train_loader=train_loader,
)
# Load a trained generator and sample images from it
def generate_images(generator, model_path, image_size):
    """Load generator weights and save an 8x8 grid of 64 generated images.

    The grid is written to ``results/generated_images.png``. The generator is
    left in training mode afterwards.

    Args:
        generator: Module whose architecture matches the checkpoint. The
            latent size is read from ``generator.input_size`` (100 if absent).
        model_path: Path to a ``state_dict`` file written by ``torch.save``.
        image_size: Side length of each saved image. The network always
            produces its own fixed resolution; when ``image_size`` differs,
            the samples are resized by bilinear interpolation. ``None`` keeps
            the native resolution.
    """
    device = next(generator.parameters()).device
    generator.load_state_dict(torch.load(model_path, map_location=device))
    latent_dim = getattr(generator, 'input_size', 100)
    generator.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        fake_images = generator(torch.randn(64, latent_dim, device=device))
        if image_size is not None and tuple(fake_images.shape[-2:]) != (image_size, image_size):
            fake_images = nn.functional.interpolate(
                fake_images, size=(image_size, image_size), mode='bilinear', align_corners=False)
    # The directory may not exist when this runs without a prior training run.
    os.makedirs('results', exist_ok=True)
    torchvision.utils.save_image(fake_images, 'results/generated_images.png', nrow=8, normalize=True)
    generator.train()


# Pass the generator's native size so the samples are saved unresized;
# pass another value (e.g. 11) to get resized images instead.
generate_images(generator, 'models/generator-100.ckpt', generator.output_size)
```
以上代码实现了一个基于 PyTorch 的 GAN,可以用来生成 64x64 的灰度图片数据集。其中,`Generator` 和 `Discriminator` 分别是生成器和判别器网络,`train` 函数是训练函数,`generate_images` 函数是用来加载模型并生成图片的函数。在 `train` 函数中,我们先训练判别器,然后训练生成器,最后保存模型和生成的图片。在 `generate_images` 函数中,我们加载训练好的模型,并生成 64x64 的灰度图片数据集。
如果要生成不同大小的灰度图片数据集,需要把 `Generator` 的 `output_size` 与 `transform` 中 `Resize`/`CenterCrop` 的尺寸改成同一个目标值,并相应调整判别器的卷积层,使其能够处理该分辨率;噪声 `z` 的维度(100)与图片大小无关,无需修改。
阅读全文