基于pytorch的输入16×16(一通道灰度图像)的图片数据集文件夹,用GAN生成扩充数据集,并保存模型,之后加载模型生成不同的16×16大小的图像,代码
时间: 2024-03-16 16:41:49 浏览: 16
好的,下面是基于 PyTorch 实现的代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
import os
# Generator network: decodes a 100-dim latent vector into a single-channel
# 16x16 image with pixel values in [-1, 1] (Tanh output).
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Fully connected stem: latent vector -> flattened 128x4x4 feature map.
        stem = [
            nn.Linear(100, 128 * 4 * 4),
            nn.BatchNorm1d(128 * 4 * 4),
            nn.ReLU(inplace=True),
        ]
        self.fc = nn.Sequential(*stem)
        # Upsampling head: two stride-2 transposed convs, 4x4 -> 8x8 -> 16x16.
        head = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.conv = nn.Sequential(*head)

    def forward(self, z):
        """Map a latent batch z of shape (N, 100) to images of shape (N, 1, 16, 16)."""
        feat = self.fc(z).view(-1, 128, 4, 4)
        return self.conv(feat)
# Discriminator network: binary classifier mapping a 1x16x16 image to the
# probability (in [0, 1]) that the image is real.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Downsampling trunk: 16x16 -> 8x8 -> 4x4, channels 1 -> 64 -> 128.
        trunk = [
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*trunk)
        # Classifier head on the flattened 128x4x4 feature map.
        classifier = [
            nn.Linear(128 * 4 * 4, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*classifier)

    def forward(self, img):
        """Return per-image realness scores of shape (N, 1) for input (N, 1, 16, 16)."""
        features = self.conv(img)
        flat = features.view(-1, 128 * 4 * 4)
        return self.fc(flat)
# Dataset over a flat folder of image files: each item is one image loaded
# in grayscale ('L') mode, with an optional transform applied.
class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Snapshot the directory listing once at construction time.
        self.file_list = os.listdir(root_dir)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        path = os.path.join(self.root_dir, self.file_list[idx])
        img = Image.open(path).convert('L')  # force single-channel grayscale
        return self.transform(img) if self.transform else img
# ---------------- Hyperparameters ----------------
batch_size = 64
lr = 0.0002          # Adam learning rate (DCGAN convention)
beta1 = 0.5          # Adam beta1 (DCGAN convention)
beta2 = 0.999
num_epochs = 100
latent_size = 100    # dimensionality of the generator's noise input
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------- Data pipeline ----------------
# FIX: the generator upsamples 4x4 -> 8x8 -> 16x16, so real images must be
# resized to 16x16 as well. The original Resize((11, 11)) produced a
# 128*2*2 flattened feature map in the discriminator, which does not match
# its fc layer's expected 128*4*4 input and crashes at runtime.
transform = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # map [0, 1] -> [-1, 1] to match Tanh
])
dataset = ImageDataset(root_dir='data', transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ---------------- Models, loss, optimizers ----------------
G = Generator().to(device)
D = Discriminator().to(device)
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2))

# ---------------- Training loop ----------------
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        # --- Discriminator step: push real -> 1, fake -> 0 ---
        D.zero_grad()
        real_images = data.to(device)
        # FIX: labels must be shape (N, 1) to match D's (N, 1) output;
        # nn.BCELoss raises on an (N,) vs (N, 1) shape mismatch.
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        noise = torch.randn(real_images.size(0), latent_size, device=device)
        fake_images = G(noise)
        fake_labels = torch.zeros(fake_images.size(0), 1, device=device)
        real_outputs = D(real_images)
        fake_outputs = D(fake_images.detach())  # detach: no gradient into G on D's step
        d_loss = criterion(real_outputs, real_labels) + criterion(fake_outputs, fake_labels)
        d_loss.backward()
        optimizer_D.step()
        # --- Generator step: make D output 1 on fresh fakes ---
        G.zero_grad()
        noise = torch.randn(real_images.size(0), latent_size, device=device)
        fake_images = G(noise)
        fake_outputs = D(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()
        # Report losses every 10 batches.
        if (i+1) % 10 == 0:
            print('[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]'
                  % (epoch+1, num_epochs, i+1, len(dataloader), d_loss.item(), g_loss.item()))

# Persist both networks so generation can run later without retraining.
torch.save(G.state_dict(), 'G.pth')
torch.save(D.state_dict(), 'D.pth')

# ---------------- Reload the generator and sample one image ----------------
G = Generator().to(device)
G.load_state_dict(torch.load('G.pth'))
G.eval()  # eval mode: BatchNorm uses running stats, so batch size 1 is fine
with torch.no_grad():
    noise = torch.randn(1, latent_size).to(device)
    fake_images = G(noise)
    # Undo the [-1, 1] normalization back into [0, 255] grayscale pixels.
    fake_images = ((fake_images + 1) / 2).cpu().numpy()
    # FIX: the generator emits 16x16 images, not 11x11.
    fake_images = fake_images.reshape((16, 16)) * 255
    fake_images = Image.fromarray(fake_images.astype('uint8')).convert('RGB')
    fake_images.show()
```
这段代码实现了一个基于 PyTorch 的 GAN 模型,用于生成扩充数据集。在训练过程中,使用了 DataLoader 类来加载数据集,并使用了二元交叉熵损失函数和 Adam 优化器来训练模型。在训练完成后,将生成器和判别器保存为模型文件,以便后续直接加载生成器、采样不同的噪声向量来生成新的 16×16 灰度图像,而无需重新训练。