基于pytorch的输入1111(一通道灰度图像)的图片数据集文件夹,用gan生成扩充数据集,并保存模型,后加载模型使生成不同的1111大小的图像,代码

时间: 2024-03-16 16:41:49 浏览: 16
好的,下面是基于 PyTorch 实现的代码: ```python import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from torchvision import transforms, utils from PIL import Image import os # 定义生成器 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.fc = nn.Sequential( nn.Linear(100, 128 * 4 * 4), nn.BatchNorm1d(128 * 4 * 4), nn.ReLU(inplace=True), ) self.conv = nn.Sequential( nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1), nn.Tanh(), ) def forward(self, x): x = self.fc(x) x = x.view(-1, 128, 4, 4) x = self.conv(x) return x # 定义判别器 class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.conv = nn.Sequential( nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True), ) self.fc = nn.Sequential( nn.Linear(128 * 4 * 4, 1), nn.Sigmoid(), ) def forward(self, x): x = self.conv(x) x = x.view(-1, 128 * 4 * 4) x = self.fc(x) return x # 定义数据集类 class ImageDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.file_list = os.listdir(root_dir) def __len__(self): return len(self.file_list) def __getitem__(self, idx): img_name = os.path.join(self.root_dir, self.file_list[idx]) image = Image.open(img_name).convert('L') if self.transform: image = self.transform(image) return image # 定义超参数 batch_size = 64 lr = 0.0002 beta1 = 0.5 beta2 = 0.999 num_epochs = 100 latent_size = 100 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 定义数据预处理 transform = transforms.Compose([ transforms.Resize((11, 11)), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]), ]) # 加载数据集 dataset = ImageDataset(root_dir='data', transform=transform) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # 定义生成器和判别器 G = Generator().to(device) D = Discriminator().to(device) # 定义损失函数和优化器 criterion = nn.BCELoss() optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2)) optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2)) # 训练模型 for epoch in range(num_epochs): for i, data in enumerate(dataloader): # 训练判别器 D.zero_grad() real_images = data.to(device) real_labels = torch.ones(real_images.size(0)).to(device) noise = torch.randn(real_images.size(0), latent_size).to(device) fake_images = G(noise) fake_labels = torch.zeros(fake_images.size(0)).to(device) real_outputs = D(real_images) fake_outputs = D(fake_images.detach()) d_loss = criterion(real_outputs, real_labels) + criterion(fake_outputs, fake_labels) d_loss.backward() optimizer_D.step() # 训练生成器 G.zero_grad() noise = torch.randn(real_images.size(0), latent_size).to(device) fake_images = G(noise) fake_outputs = D(fake_images) g_loss = criterion(fake_outputs, real_labels) g_loss.backward() optimizer_G.step() # 输出当前的损失 if (i+1) % 10 == 0: print('[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]' % (epoch+1, num_epochs, i+1, len(dataloader), d_loss.item(), g_loss.item())) # 保存模型 torch.save(G.state_dict(), 'G.pth') torch.save(D.state_dict(), 'D.pth') # 加载模型并生成图像 G = Generator().to(device) G.load_state_dict(torch.load('G.pth')) G.eval() with torch.no_grad(): noise = torch.randn(1, latent_size).to(device) fake_images = G(noise) fake_images = ((fake_images + 1) / 2).cpu().numpy() fake_images = fake_images.reshape((11, 11)) * 255 fake_images = Image.fromarray(fake_images.astype('uint8')).convert('RGB') fake_images.show() ``` 这段代码实现了一个基于 PyTorch 的 GAN 模型,用于生成扩充数据集。在训练过程中,使用了 DataLoader 类来加载数据集,并使用了交叉熵损失函数和 Adam 优化器来训练模型。在训练完成后,将生成器和判别器保存为模型文件,以便后续

相关推荐

最新推荐

recommend-type

Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式

今天小编就为大家分享一篇Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

pytorch学习教程之自定义数据集

自定义数据集 在训练深度学习模型之前,样本集的制作非常重要。在pytorch中,提供了一些接口和类,方便我们定义自己的数据集合,下面完整的试验自定义样本集的整个流程。 开发环境 Ubuntu 18.04 pytorch 1.0 ...
recommend-type

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

由于最近目标是完成基于深度学习的脑肿瘤语义分割实验,所以需要用到自定义的数据载入,本文参考了一下博客:https://blog.csdn.net/tuiqdymy/article/details/84779716?utm_source=app,一开始是做的眼底图像分割,...
recommend-type

pytorch实现mnist数据集的图像可视化及保存

今天小编就为大家分享一篇pytorch实现mnist数据集的图像可视化及保存,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

PyTorch版YOLOv4训练自己的数据集—基于Google Colab

Google Colaboratory是谷歌开放的一款研究工具,主要用于机器学习的开发和研究。 工具优势:Google Colab最大的好处是给广大的AI开发者提供了免费的GPU使用。你可以在上面轻松地跑例如:Keras、Tensorflow、Pytorch...
recommend-type

zigbee-cluster-library-specification

最新的zigbee-cluster-library-specification说明文档。
recommend-type

管理建模和仿真的文件

管理Boualem Benatallah引用此版本:布阿利姆·贝纳塔拉。管理建模和仿真。约瑟夫-傅立叶大学-格勒诺布尔第一大学,1996年。法语。NNT:电话:00345357HAL ID:电话:00345357https://theses.hal.science/tel-003453572008年12月9日提交HAL是一个多学科的开放存取档案馆,用于存放和传播科学研究论文,无论它们是否被公开。论文可以来自法国或国外的教学和研究机构,也可以来自公共或私人研究中心。L’archive ouverte pluridisciplinaire
recommend-type

实现实时数据湖架构:Kafka与Hive集成

![实现实时数据湖架构:Kafka与Hive集成](https://img-blog.csdnimg.cn/img_convert/10eb2e6972b3b6086286fc64c0b3ee41.jpeg) # 1. 实时数据湖架构概述** 实时数据湖是一种现代数据管理架构,它允许企业以低延迟的方式收集、存储和处理大量数据。与传统数据仓库不同,实时数据湖不依赖于预先定义的模式,而是采用灵活的架构,可以处理各种数据类型和格式。这种架构为企业提供了以下优势: - **实时洞察:**实时数据湖允许企业访问最新的数据,从而做出更明智的决策。 - **数据民主化:**实时数据湖使各种利益相关者都可
recommend-type

list根据id查询pid 然后依次获取到所有的子节点数据

可以使用递归的方式来实现根据id查询pid并获取所有子节点数据。具体实现可以参考以下代码: ``` def get_children_nodes(nodes, parent_id): children = [] for node in nodes: if node['pid'] == parent_id: node['children'] = get_children_nodes(nodes, node['id']) children.append(node) return children # 测试数
recommend-type

JSBSim Reference Manual

JSBSim参考手册,其中包含JSBSim简介,JSBSim配置文件xml的编写语法,编程手册以及一些应用实例等。其中有部分内容还没有写完,估计有生之年很难看到完整版了,但是内容还是很有参考价值的。