『PyTorch』第五弹_深入理解Tensor对象_下
时间: 2024-09-19 20:01:08 浏览: 39
PyTorch中的Tensor是一个核心数据结构,它是张量库的基础,类似于NumPy数组。Tensor对象可以看作是一维、多维的数值矩阵,支持动态维度和GPU加速计算。以下是关于Tensor对象的一些深入理解:
1. **创建Tensor**:你可以直接通过`torch.tensor()`函数创建Tensor,也可以从Python列表、NumPy数组或者其他Tensor转换而来。
```python
import torch
# 创建一维Tensor
t = torch.tensor([1, 2, 3])
# 或者从列表创建
t_list = [4, 5, 6]
t_list_tensor = torch.tensor(t_list)
```
2. **属性和方法**:Tensor有多种属性如形状(shape)、大小(size),以及诸如索引(indexing)、切片(slicing)、数学运算(math operations)等常用方法。
```python
print(t.shape) # 输出 (3,)
print(t[0]) # 输出 1
t.add_(1) # 在原地加1
```
3. **类型和维度**:Tensor有固定的类型(如int, float),并可以是任意维度,包括标量(0维)、向量(1维)、矩阵(2维)等。
4. **GPU支持**:PyTorch对CUDA(NVIDIA GPU的计算平台)提供了很好的支持,可以通过`.to('cuda')`将Tensor移动到GPU上进行加速。
```python
if torch.cuda.is_available():
t_gpu = t.to('cuda')
```
阅读全文