基于 PyTorch,输入为 11×11(单通道灰度图像)的图片数据集文件夹,用 GAN 生成图像来扩充数据集,并保存模型和生成的图片,之后加载模型生成不同的 11×11 大小的图像,代码
时间: 2024-03-16 20:42:04 浏览: 88
好的,您可以参考下面的代码实现:
首先,需要安装 PyTorch、torchvision、Pillow、NumPy 和 Matplotlib 库。然后,导入必要的包和模块:
``` python
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
```
接下来,定义一个类用于加载数据集:
``` python
class ImageDataset(torch.utils.data.Dataset):
    """Flat folder of images, each served as a single-channel (grayscale) sample.

    Args:
        root: Directory that directly contains the image files.
        transform: Optional callable applied to each PIL image before it is
            returned (e.g. a torchvision transform pipeline).
        extensions: File suffixes (lower-case, with the dot) that count as
            images. Everything else in ``root`` is ignored.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif',
                        '.tif', '.tiff', '.webp', '.pgm', '.ppm')

    def __init__(self, root, transform=None, extensions=IMAGE_EXTENSIONS):
        self.root = root
        self.transform = transform
        suffixes = tuple(ext.lower() for ext in extensions)
        # Keep only regular image files: os.listdir also returns
        # sub-directories and stray files (.DS_Store, Thumbs.db, labels.txt)
        # that would make Image.open fail in the middle of training.
        # Sorting makes the index -> file mapping reproducible, because
        # os.listdir order is arbitrary.
        self.images = sorted(
            name for name in os.listdir(root)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(root, name))
        )

    def __len__(self):
        """Return the number of image files found in ``root``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load image ``index`` as grayscale and apply the optional transform."""
        img_path = os.path.join(self.root, self.images[index])
        # convert() returns an independent in-memory copy, so the file handle
        # can be released as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('L')
        if self.transform is not None:
            image = self.transform(image)
        return image
```
然后,定义生成器和判别器的模型:
``` python
class Generator(nn.Module):
    """DCGAN-style generator: latent vector -> 1x28x28 image in [-1, 1].

    The transposed-convolution stack grows the spatial size
    1 -> 4 -> 7 -> 14 -> 28, so the output matches the 28x28 size that the
    data pipeline resizes the real images to.

    Args:
        z_dim: Number of channels of the latent input, which is expected to
            have shape ``(N, z_dim, 1, 1)``.
    """

    def __init__(self, z_dim=100):
        super().__init__()
        self.main = nn.Sequential(
            # (N, z_dim, 1, 1) -> (N, 512, 4, 4)
            nn.ConvTranspose2d(z_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 4x4 -> 7x7: (4 - 1) * 2 - 2 * 1 + 3 = 7. A kernel of 3 (not 4)
            # is what lets the following two doublings land exactly on 28.
            nn.ConvTranspose2d(512, 256, 3, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 14x14 -> 28x28, single grayscale channel
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),
            # Tanh keeps pixels in [-1, 1], the range of images normalised
            # with mean 0.5 / std 0.5.
            nn.Tanh(),
        )

    def forward(self, x):
        """Map latent noise ``(N, z_dim, 1, 1)`` to images ``(N, 1, 28, 28)``."""
        return self.main(x)
class Discriminator(nn.Module):
    """DCGAN-style discriminator for single-channel images.

    Four stride-2 convolutions roughly halve the spatial size each time
    (28 -> 14 -> 7 -> 3 -> 1 for a 28x28 input). The resulting feature map is
    globally average-pooled and mapped to one probability per image, so any
    square or rectangular input of at least 16x16 pixels is accepted.

    Note: for a 28x28 input the last BatchNorm layer sees a 1x1 feature map,
    so a training-mode forward pass needs at least 2 images in the batch.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # A fixed 4x4 "valid" convolution as the head would only fit a
            # 64x64 input (4x4 feature map) and raise on 28x28 images, whose
            # feature map is 1x1 here. Global pooling followed by a 1x1
            # convolution makes the head independent of the input size.
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the probability that each image is real, shape ``(N, 1)``."""
        return self.main(x).view(-1, 1)
```
接下来,定义超参数、图像预处理(transform)和数据集的路径:
``` python
# Training hyperparameters.
batch_size = 64    # images per optimisation step
lr = 0.0002        # Adam learning rate
beta1 = 0.5        # first Adam beta (momentum term)
epochs = 50        # full passes over the dataset
z_dim = 100        # channels of the latent noise fed to the generator
image_size = 28    # side length (pixels) the real images are resized to

# Pre-processing applied to every real image.
transform = transforms.Compose([
    # Pass (height, width) explicitly: a bare int only rescales the shorter
    # edge and keeps the aspect ratio, so non-square images would come out
    # with different shapes and could not be stacked into a batch.
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    # Maps pixel values from [0, 1] to [-1, 1].
    transforms.Normalize((0.5,), (0.5,)),
])

# Folder that directly contains the training images.
data_path = './data'
```
然后,加载数据集:
``` python
# Wrap the image folder in a dataset, then batch and reshuffle it every epoch.
dataset = ImageDataset(root=data_path, transform=transform)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)
```
接着,初始化生成器和判别器的模型、损失函数以及各自的优化器:
``` python
# Generator network and the Adam optimizer that updates it.
generator = Generator()
optimizer_g = optim.Adam(
    params=generator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)

# Discriminator network and its own Adam optimizer (same settings).
discriminator = Discriminator()
optimizer_d = optim.Adam(
    params=discriminator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)

# Binary cross-entropy between predicted probabilities and real/fake targets.
criterion = nn.BCELoss()
```
然后,定义训练过程:
``` python
# Sample grids and checkpoints are written here; create the folder up front
# so the first save does not fail when it is missing.
os.makedirs('./output', exist_ok=True)

for epoch in range(epochs):
    for i, real in enumerate(dataloader):
        # Size of *this* batch (the last one may be smaller). Kept in a local
        # name so the batch_size hyperparameter is not overwritten.
        n_samples = real.size(0)
        # Targets are (N, 1) to match the discriminator output; BCELoss
        # rejects a target whose shape differs from the prediction.
        label_real = torch.ones(n_samples, 1, dtype=torch.float32)
        label_fake = torch.zeros(n_samples, 1, dtype=torch.float32)

        # Train Discriminator
        discriminator.zero_grad()
        z = torch.randn(n_samples, z_dim, 1, 1)
        fake = generator(z)
        output_real = discriminator(real)
        # detach(): the discriminator step must not backpropagate into the
        # generator.
        output_fake = discriminator(fake.detach())
        loss_d = criterion(output_real, label_real) + criterion(output_fake, label_fake)
        loss_d.backward()
        optimizer_d.step()

        # Train Generator: it is rewarded when the discriminator labels its
        # samples as real.
        generator.zero_grad()
        output_fake = discriminator(fake)
        loss_g = criterion(output_fake, label_real)
        loss_g.backward()
        optimizer_g.step()

        if i % 100 == 0:
            print(f'[{epoch+1}/{epochs}] [{i}/{len(dataloader)}] Loss_D: {loss_d.item():.4f} Loss_G: {loss_g.item():.4f}')

    # Save generated images (once per epoch).
    # eval() + no_grad(): use the BatchNorm running statistics and build no
    # autograd graph; a tensor that requires grad cannot be converted to a
    # NumPy array for plotting.
    generator.eval()
    with torch.no_grad():
        z = torch.randn(10, z_dim, 1, 1)
        fake = generator(z)
    generator.train()
    fake_images = torchvision.utils.make_grid(fake, nrow=10, normalize=True)
    plt.imshow(fake_images.permute(1, 2, 0).numpy())
    plt.axis('off')
    plt.savefig(f'./output/fake_images_{epoch+1}.png')
    plt.close()

    # Save models
    torch.save(generator.state_dict(), f'./output/generator_{epoch+1}.pth')
    torch.save(discriminator.state_dict(), f'./output/discriminator_{epoch+1}.pth')
```
最后,定义一个函数用于生成不同大小的图像:
``` python
def generate_image(model_path, image_size):
    """Generate one grayscale image of the requested size from a saved generator.

    Args:
        model_path: Path to a generator ``state_dict`` written by
            ``torch.save``.
        image_size: Side length in pixels of the (square) image to return.

    Returns:
        A 2-D NumPy array of shape ``(image_size, image_size)`` with values
        in ``[0, 1]``.
    """
    generator = Generator()
    # map_location lets a checkpoint that was saved on a GPU be loaded on a
    # CPU-only machine.
    generator.load_state_dict(torch.load(model_path, map_location='cpu'))
    generator.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        z = torch.randn(1, z_dim, 1, 1)
        fake = generator(z)
        # view() can only relabel dimensions, not change the pixel count, so
        # a size other than the generator's native one has to be resampled.
        if tuple(fake.shape[-2:]) != (image_size, image_size):
            fake = nn.functional.interpolate(
                fake,
                size=(image_size, image_size),
                mode='bilinear',
                align_corners=False,
            )
    # Undo the [-1, 1] Tanh range -> [0, 1] for display or saving.
    fake_image = (fake + 1) / 2
    return fake_image.numpy()[0][0]
```
您可以使用以上代码实现您的需求。
阅读全文