基于pytorch的输入1111(一通道灰度图像)的图片数据集文件夹,用gan生成扩充数据集,并保存模型和生成的图片,后加载模型使生成不同的1111大小的图像,代码

时间: 2024-03-16 17:42:04 浏览: 15
好的,您可以参考下面的代码实现: 首先,需要安装PyTorch和torchvision库。然后,导入必要的包和模块: ``` python import torch import torchvision from torchvision import transforms from torch.utils.data import DataLoader import torch.nn as nn import torch.optim as optim import os from PIL import Image import numpy as np import matplotlib.pyplot as plt ``` 接下来,定义一个类用于加载数据集: ``` python class ImageDataset(torch.utils.data.Dataset): def __init__(self, root, transform=None): self.root = root self.transform = transform self.images = os.listdir(root) def __len__(self): return len(self.images) def __getitem__(self, index): img_path = os.path.join(self.root, self.images[index]) image = Image.open(img_path).convert('L') if self.transform is not None: image = self.transform(image) return image ``` 然后,定义生成器和判别器的模型: ``` python class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.main = nn.Sequential( nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False), nn.Tanh() ) def forward(self, x): x = self.main(x) return x class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.main = nn.Sequential( nn.Conv2d(1, 64, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(128, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(256, 512, 4, 2, 1, bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(512, 1, 4, 1, 0, bias=False), nn.Sigmoid() ) def forward(self, x): x = self.main(x) return x.view(-1, 1) ``` 接下来,定义超参数和数据集的路径: ``` python batch_size = 64 lr = 0.0002 beta1 = 0.5 epochs = 50 z_dim = 100 image_size = 28 transform = transforms.Compose([ transforms.Resize(image_size), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) data_path = './data' ``` 然后,加载数据集: ``` python dataset = ImageDataset(data_path, transform) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) ``` 接着,初始化生成器和判别器的模型和损失函数: ``` python generator = Generator() discriminator = Discriminator() criterion = nn.BCELoss() optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999)) optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999)) ``` 然后,定义训练过程: ``` python for epoch in range(epochs): for i, data in enumerate(dataloader): real = data batch_size = real.size(0) label_real = torch.full((batch_size,), 1, dtype=torch.float32) label_fake = torch.full((batch_size,), 0, dtype=torch.float32) # Train Discriminator discriminator.zero_grad() z = torch.randn(batch_size, z_dim, 1, 1) fake = generator(z) output_real = discriminator(real) output_fake = discriminator(fake.detach()) loss_d = criterion(output_real, label_real) + criterion(output_fake, label_fake) loss_d.backward() optimizer_d.step() # Train Generator generator.zero_grad() output_fake = discriminator(fake) loss_g = criterion(output_fake, label_real) loss_g.backward() optimizer_g.step() if i % 100 == 0: print(f'[{epoch+1}/{epochs}] [{i}/{len(dataloader)}] Loss_D: {loss_d.item():.4f} Loss_G: {loss_g.item():.4f}') # Save generated images z = torch.randn(10, z_dim, 1, 1) fake = generator(z) fake_images = torchvision.utils.make_grid(fake, nrow=10, normalize=True) plt.imshow(fake_images.permute(1, 2, 0)) plt.axis('off') plt.savefig(f'./output/fake_images_{epoch+1}.png') plt.close() # Save models torch.save(generator.state_dict(), f'./output/generator_{epoch+1}.pth') torch.save(discriminator.state_dict(), f'./output/discriminator_{epoch+1}.pth') ``` 最后,定义一个函数用于生成不同大小的图像: ``` python def generate_image(model_path, image_size): generator = Generator() generator.load_state_dict(torch.load(model_path)) generator.eval() z = torch.randn(1, z_dim, 1, 1) fake = generator(z) fake_image = fake.view(1, 1, image_size, image_size) fake_image = (fake_image + 1) / 2 fake_image = fake_image.detach().numpy()[0][0] return fake_image ``` 您可以使用以上代码实现您的需求。

相关推荐

最新推荐

recommend-type

Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式

今天小编就为大家分享一篇Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

pytorch学习教程之自定义数据集

在pytorch中,提供了一些接口和类,方便我们定义自己的数据集合,下面完整的试验自定义样本集的整个流程。 开发环境 Ubuntu 18.04 pytorch 1.0 pycharm 实验目的 掌握pytorch中数据集相关的API接口和类 熟悉...
recommend-type

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

由于最近目标是完成基于深度学习的脑肿瘤语义分割实验,所以需要用到自定义的数据载入,本文参考了一下博客:https://blog.csdn.net/tuiqdymy/article/details/84779716?utm_source=app,一开始是做的眼底图像分割,...
recommend-type

pytorch实现mnist数据集的图像可视化及保存

今天小编就为大家分享一篇pytorch实现mnist数据集的图像可视化及保存,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

PyTorch版YOLOv4训练自己的数据集—基于Google Colab

Google Colaboratory是谷歌开放的一款研究工具,主要用于机器学习的开发和研究。 工具优势:Google Colab最大的好处是给广大的AI开发者提供了免费的GPU使用。你可以在上面轻松地跑例如:Keras、Tensorflow、Pytorch...
recommend-type

zigbee-cluster-library-specification

最新的zigbee-cluster-library-specification说明文档。
recommend-type

管理建模和仿真的文件

管理Boualem Benatallah引用此版本:布阿利姆·贝纳塔拉。管理建模和仿真。约瑟夫-傅立叶大学-格勒诺布尔第一大学,1996年。法语。NNT:电话:00345357HAL ID:电话:00345357https://theses.hal.science/tel-003453572008年12月9日提交HAL是一个多学科的开放存取档案馆,用于存放和传播科学研究论文,无论它们是否被公开。论文可以来自法国或国外的教学和研究机构,也可以来自公共或私人研究中心。L’archive ouverte pluridisciplinaire
recommend-type

实现实时数据湖架构:Kafka与Hive集成

![实现实时数据湖架构:Kafka与Hive集成](https://img-blog.csdnimg.cn/img_convert/10eb2e6972b3b6086286fc64c0b3ee41.jpeg) # 1. 实时数据湖架构概述** 实时数据湖是一种现代数据管理架构,它允许企业以低延迟的方式收集、存储和处理大量数据。与传统数据仓库不同,实时数据湖不依赖于预先定义的模式,而是采用灵活的架构,可以处理各种数据类型和格式。这种架构为企业提供了以下优势: - **实时洞察:**实时数据湖允许企业访问最新的数据,从而做出更明智的决策。 - **数据民主化:**实时数据湖使各种利益相关者都可
recommend-type

SPDK_NVMF_DISCOVERY_NQN是什么 有什么作用

SPDK_NVMF_DISCOVERY_NQN 是 SPDK (Storage Performance Development Kit) 中用于查询 NVMf (Non-Volatile Memory express over Fabrics) 存储设备名称的协议。NVMf 是一种基于网络的存储协议,可用于连接远程非易失性内存存储器。 SPDK_NVMF_DISCOVERY_NQN 的作用是让存储应用程序能够通过 SPDK 查询 NVMf 存储设备的名称,以便能够访问这些存储设备。通过查询 NVMf 存储设备名称,存储应用程序可以获取必要的信息,例如存储设备的IP地址、端口号、名称等,以便能
recommend-type

JSBSim Reference Manual

JSBSim参考手册,其中包含JSBSim简介,JSBSim配置文件xml的编写语法,编程手册以及一些应用实例等。其中有部分内容还没有写完,估计有生之年很难看到完整版了,但是内容还是很有参考价值的。