使用PyTorch的数据集和数据加载器来逐批加载数据
时间: 2024-03-20 20:39:24 浏览: 78
MLP识别MNIST手写数字数据集(Pytorch)
当你的数据集很大时,你可以使用PyTorch的数据集和数据加载器来逐批加载数据。这样,你可以在内存有限的情况下有效地加载和使用大型数据集。
以下是一个使用PyTorch的数据集和数据加载器来逐批加载数据的示例代码:
```python
import os
import scipy.io as sio
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
# 定义自定义数据集类
class MyDataset(Dataset):
def __init__(self, data_path):
self.file_paths = []
for root, dirs, files in os.walk(data_path):
for file in files:
if file.endswith(".mat"):
self.file_paths.append(os.path.join(root, file))
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
mat_data = sio.loadmat(self.file_paths[idx])
np_data = np.array(mat_data['data'])
return np_data
# 定义数据路径和批量大小
data_path = "/path/to/data/folder"
batch_size = 32
# 创建自定义数据集对象和数据加载器对象
dataset = MyDataset(data_path)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 遍历数据加载器并输出数据批次的形状
for i, batch in enumerate(dataloader):
print("Batch ", i, " shape: ", batch.shape)
```
这个代码与之前的代码类似,但是使用了PyTorch的数据集和数据加载器来逐批加载数据。自定义数据集类(MyDataset)用于从磁盘加载.mat文件并将其转换为Numpy数组。数据加载器(DataLoader)用于加载数据集中的批次数据。
请注意,在这个示例代码中,我们使用了shuffle=True来打乱数据集。这是一个很好的实践,因为它可以帮助模型更好地学习数据集中的模式。如果你的数据集已经按照某种顺序排列好了,你可以将shuffle设置为False。
这个示例代码只是一个简单的演示如何使用PyTorch的数据集和数据加载器来逐批加载数据。你可以根据自己的需求修改代码以适应你的数据集和模型。
阅读全文