没有合适的资源?快使用搜索试试~ 我知道了~
首页pytorch实现建立自己的数据集(以mnist为例)
pytorch实现建立自己的数据集(以mnist为例)
13 下载量 48 浏览量
更新于2023-03-03
评论
收藏 190KB PDF 举报
本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。 加载并保存图像信息 首先导入需要的库,定义各种路径。 import os import matplotlib from keras.datasets import mnist import numpy as np from torch.utils.data.dataset import Dataset from PIL import Image import scipy.misc root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
资源详情
资源评论
资源推荐
pytorch实现建立自己的数据集实现建立自己的数据集(以以mnist为例为例)
本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。
加载并保存图像信息加载并保存图像信息
首先导入需要的库,定义各种路径。
import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc
root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'
这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过
keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。
def LoadData(root_path, base_path, training_path, test_path):
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_baseset = np.concatenate((x_train, x_test))
y_baseset = np.concatenate((y_train, y_test))
train_num = len(x_train)
test_num = len(x_test)
#baseset
file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
for i in range(train_num + test_num):
file_img.write(root_path + base_path + 'img/' + str(i) + '.png') #name
file_label.write(str(y_baseset[i])+'') #label
# scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
file_img.close()
file_label.close()
#trainingset
file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
for i in range(train_num):
file_img.write(root_path + training_path + 'img/' + str(i) + '.png') #name
file_label.write(str(y_train[i])+'') #label
# scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
file_img.close()
file_label.close()
#testset
file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
for i in range(test_num):
file_img.write(root_path + test_path + 'img/' + str(i) + '.png') #name
file_label.write(str(y_test[i])+'') #label
# scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
file_img.close()
file_label.close()
使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:
weixin_38721691
- 粉丝: 4
- 资源: 907
上传资源 快速赚钱
- 我的内容管理 收起
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
会员权益专享
最新资源
- RTL8188FU-Linux-v5.7.4.2-36687.20200602.tar(20765).gz
- c++校园超市商品信息管理系统课程设计说明书(含源代码) (2).pdf
- 建筑供配电系统相关课件.pptx
- 企业管理规章制度及管理模式.doc
- vb打开摄像头.doc
- 云计算-可信计算中认证协议改进方案.pdf
- [详细完整版]单片机编程4.ppt
- c语言常用算法.pdf
- c++经典程序代码大全.pdf
- 单片机数字时钟资料.doc
- 11项目管理前沿1.0.pptx
- 基于ssm的“魅力”繁峙宣传网站的设计与实现论文.doc
- 智慧交通综合解决方案.pptx
- 建筑防潮设计-PowerPointPresentati.pptx
- SPC统计过程控制程序.pptx
- SPC统计方法基础知识.pptx
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
评论0