transforms.ToTensor()
时间: 2023-06-22 19:27:16 浏览: 150
`transforms.ToTensor()` 是 PyTorch 中的一个数据变换(data transformation),它将 PIL 图像或 numpy 数组(array)转换为 PyTorch 张量(Tensor)。具体来说,它会将图像或数组中的数据类型转换为 `torch.FloatTensor`,并将数值范围从 $[0, 255]$(对应图像或数组中的像素值)缩放到 $[0, 1]$(对应张量中的数值)。如果输入是一个灰度图像,它将得到一个形状为 `(1, H, W)` 的张量,其中 `H` 和 `W` 分别表示图像的高度和宽度;如果输入是一个 RGB 图像,它将得到一个形状为 `(3, H, W)` 的张量,其中第一个维度表示颜色通道(红、绿、蓝)。
在深度学习中,我们通常需要将输入数据转换为张量,并对其进行标准化等预处理,以便于网络的训练和推理。`transforms.ToTensor()` 是常用的数据变换之一,通常在数据预处理的流水线中使用,例如:
```python
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
```
这个代码片段定义了一个 `transform` 变量,它将输入图像先缩放到 $224 \times 224$ 的大小,然后通过 `transforms.ToTensor()` 将其转换为张量,最后使用 `transforms.Normalize()` 函数进行标准化。这个变换可以用于对图像数据进行预处理,以便于输入到深度学习模型中。
阅读全文