pytorch怎么把数据集加载到GPU
时间: 2023-11-17 07:06:08 浏览: 133

pytorch 自定义数据集加载方法

将数据集加载到GPU需要使用PyTorch中的DataLoader和torch.cuda模块。
首先,需要确定设备是否支持CUDA:
```python
import torch
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
```
然后,将数据集转换为PyTorch张量,并将其放置在定义的设备上:
```python
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
images, labels = next(iter(train_loader))
images = images.to(device)
labels = labels.to(device)
```
在这里,我们使用MNIST数据集作为示例。首先,我们将数据集转换为张量,然后使用DataLoader将其转换为批次。最后,我们将批次中的图像和标签移动到定义的设备上。
注意:如果你使用的是自己的数据集,需要根据数据集的格式进行相应的修改。
阅读全文
相关推荐
















