import os.path import gzip import pickle import os import numpy as np import urllib url_base = 'http://yann.lecun.com/exdb/mnist/' key_file = { 'train_img':'train-images-idx3-ubyte.gz', 'train_label':'train-labels-idx1-ubyte.gz', 'test_img':'t10k-images-idx3-ubyte.gz', 'test_label':'t10k-labels-idx1-ubyte.gz' } dataset_dir = os.path.dirname(os.path.abspath("_file_")) save_file = dataset_dir + "/mnist.pkl" train_num=60000 test_num=10000 img_dim=(1,28,28) img_size=784 def _download(file_name): file_path = dataset_dir+"/"+file_name if os.path.exists(file_path): return print("Downloading"+file_name+" ... ") urllib.request.urlretrieve(url_base + file_name,file_path) print("Done") def download_mnist(): for v in key_file.values(): _download(v) def _load_label(file_name): file_path = dataset_dir+ "/" +file_name print("Converting" + file_name +"to Numpy Array ...") with gzip.open(file_path,'rb') as f: labels = np.frombuffer(f.read(),np.uint8,offset=8) print("Done") return labels def _load_img(file_name): file_path=dataset_dir+"/"+file_name print("Converting"+file_name+"to Numpy Array ...") with gzip.open(file_path,'rb') as f: data = np.frombuffer(f.read(),np.uint8,offset=16) data = data.reshape(-1,img_size) print("Done") return data def _convert_numpy(): dataset = {} dataset['train_img'] = _load_img(key_file['train_img']) dataset['train_label'] = _load_label(key_file['train_label']) dataset['test_img'] = _load_img(key_file['test_img']) dataset['test_label'] = _load_label(key_file['test_label']) return dataset def init_mnist(): download_mnist() dataset = _convert_numpy() print("Creating pickle file ...") with open(save_file,'wb') as f: pickle.dump(dataset,f,-1) print("Done") if __name__ =='__main__': init_mnist()
时间: 2024-04-03 19:31:26 浏览: 349
这段代码是用于下载MNIST数据集并将数据集转换成Numpy数组格式的函数。MNIST数据集是一个手写数字识别数据集,包含了60000张训练图片和10000张测试图片。在函数中,首先定义了数据集的下载地址和四个文件的名称,然后定义了四个函数用于下载和转换数据集。其中,_load_label和_load_img函数用于将标签和图片数据转换成Numpy数组格式,并且在转换过程中使用了gzip库解压缩数据。_convert_numpy函数用于将四个Numpy数组合成一个字典类型的数据集。最后,init_mnist函数用于下载数据集并将数据集转换成Numpy数组格式,并使用pickle库将数据集保存到本地文件中。如果MNIST数据集已经下载并保存到本地文件中,则直接加载本地文件中的数据集。
相关问题
import torch, os, pickle, random import numpy as np from yaml import safe_load as yaml_load from json import dumps as json_dumps def load_data(data_path): with open(data_path, 'rb') as f: data = pickle.load(f) return data def save_model(model, save_path, optimizer=None): os.makedirs(os.path.dirname(save_path), exist_ok=True) data2save = { 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), } torch.save(data2save, save_path) def load_model(model, load_path, optimizer=None): data2load = torch.load(load_path, map_location='cpu') model.load_state_dict(data2load['state_dict']) if optimizer is not None and data2load['optimizer'] is not None: optimizer = data2load['optimizer'] def fix_random_seed_as(seed): random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False if __name__ == "__main__": pass
这段代码是一个示例的辅助函数和主函数。
`load_data` 函数用于加载保存在 pickle 文件中的数据。它接受一个数据路径作为输入,并使用 pickle 库将数据从文件中加载出来。
`save_model` 函数用于保存模型及其优化器的状态。它接受一个模型对象、保存路径和可选的优化器对象作为输入,并将模型的状态字典和优化器的状态字典保存到指定的路径中。
`load_model` 函数用于加载模型及其优化器的状态。它接受一个模型对象、加载路径和可选的优化器对象作为输入,并从指定的路径中加载模型的状态字典,并将其应用于给定的模型对象。如果给定了优化器对象且加载的状态中包含优化器的状态字典,则还会将加载的优化器状态应用于给定的优化器对象。
`fix_random_seed_as` 函数用于设置随机种子,以确保实验的可重复性。它接受一个种子值作为输入,并使用该种子值设置随机数生成器的种子。
最后,`if __name__ == "__main__":` 语句是一个条件判断语句,用于判断当前脚本是否作为主程序运行。如果是主程序运行,则 `pass` 语句表示主函数为空,即没有特定的主要代码逻辑。
import os from random import shuffle import pickle import librosa import tensorflow.keras as keras import librosa.display import numpy as np from matplotlib import pyplot as plt def load_files(audio_dir): files=os.listdir(audio_dir) wav_files=[] for wav in files: if not wav.endswith('.wav'): continue wav_files.append(wav) if not wav_files: print('未找到数据集') shuffle(wav_files) nfiles=len(wav_files) ntrain=int(nfiles*0.7) return wav_files[:ntrain],wav_files[ntrain:] audio_dir='../data/recordings/' dataset_pickle='../tmp/recordings.pkl' train_files,valid_files=load_files(audio_dir) print('训练集样本数为{}\n验证集样本数为{}'.format(len(train_files),load_files(valid_files)))帮我检查错误
这段代码中有两个问题:
1. 在函数 `load_files()` 中,`wav_files` 列表只会包含以 `.wav` 结尾的文件名,因为 `if not wav.endswith('.wav')` 和 `wav_files.append(wav)` 没有缩进到同一个代码块中。你需要将两行代码缩进到同一级别,才能正确的将所有 `.wav` 文件名添加到 `wav_files` 列表中。
2. 在 `print()` 函数中使用了 `load_files()` 函数调用的返回值,而 `load_files()` 函数返回的是两个列表类型的对象,因此会报错。你需要将 `len(load_files(valid_files))` 改为 `len(valid_files)` 才能正确输出验证集样本数。
下面是已经修改过后的代码:
```python
import os
from random import shuffle
import pickle
import librosa
import tensorflow.keras as keras
import librosa.display
import numpy as np
from matplotlib import pyplot as plt
def load_files(audio_dir):
files = os.listdir(audio_dir)
wav_files = []
for wav in files:
if not wav.endswith('.wav'):
continue
wav_files.append(wav)
if not wav_files:
print('未找到数据集')
shuffle(wav_files)
n_files = len(wav_files)
n_train = int(n_files * 0.7)
return wav_files[:n_train], wav_files[n_train:]
audio_dir = '../data/recordings/'
dataset_pickle = '../tmp/recordings.pkl'
train_files, valid_files = load_files(audio_dir)
print('训练集样本数为{}\n验证集样本数为{}'.format(len(train_files), len(valid_files)))
```
希望能帮到你!
阅读全文
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)