datasets.CIFAR10
时间: 2024-01-25 22:13:18 浏览: 91
`datasets.CIFAR10`是PyTorch中`torchvision.datasets`模块下的一个类,用于加载CIFAR10数据集。CIFAR10是一个常用的计算机视觉数据集,包含了10个不同类别的60000个32x32彩色图像,每个类别有6000个图像。该数据集被广泛用于图像分类和目标识别任务。
以下是使用`datasets.CIFAR10`加载CIFAR10数据集的示例代码:
```python
import torchvision.datasets as datasets
# 加载CIFAR10数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True)
# 获取数据集的大小
train_size = len(train_dataset)
test_size = len(test_dataset)
# 获取数据集的类别标签
classes = train_dataset.classes
# 打印数据集信息
print("训练集大小:", train_size)
print("测试集大小:", test_size)
print("类别标签:", classes)
```
这段代码首先导入了`torchvision.datasets`模块,并使用`datasets.CIFAR10`类分别加载了训练集和测试集。`root`参数指定了数据集的存储路径,`train`参数指定了是否加载训练集,`download`参数指定了是否自动下载数据集(如果数据集不存在的话)。
然后,代码通过`len()`函数获取了训练集和测试集的大小,并使用`classes`属性获取了数据集的类别标签。最后,代码打印了数据集的信息,包括训练集大小、测试集大小和类别标签。
阅读全文