torch.sparse.FloatTensor
时间: 2024-09-13 14:06:34 浏览: 46
torch_sparse-0.6.10-cp39-cp39-linux_x86_64whl.zip
`torch.sparse.FloatTensor` 是PyTorch库中的一个类,用于表示稀疏张量,特别是在处理大型数据集时,它可以节省大量的内存空间。它主要用于COO (Coordinate List) 格式,其中包含了非零元素的索引和对应的值。
创建`torch.sparse.FloatTensor`需要以下三个参数:
1. `indices`: 非零元素的坐标,通常是一个LongTensor。
2. `values`: 对应于给定坐标值的张量。
3. `sizes`: 表示稀疏张量的整体形状,即使大部分区域是零。
下面是一些操作:
1. **创建**:
```python
i = torch.LongTensor([[0, 1, 1], [2, 1, 0]])
d = torch.tensor([3, 6, 9], dtype=torch.float)
a = torch.sparse.FloatTensor(i, d, torch.Size([2, 3]))
```
这会创建一个2x3的稀疏张量,有3个非零元素,位于位置(0,1),(1,1),(2,1),对应值分别为3, 6, 和9。
2. **转换为密集形式**:
```python
a.to_dense() # 返回一个稠密版本的tensor
```
如你所给出的示例,对于第一个a,结果是[[0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]。
3. **打印输出**:
```python
print(a) # 打印原始的sparse tensor信息
```
阅读全文