pytorch 按照索引拆分
时间: 2023-11-06 17:56:49 浏览: 154
Pytorch Tensor的索引与切片例子
5星 · 资源好评率100%
一个张量的方法是什么?
PyTorch 提供了一个名为 split() 的方法,可以按照指定的维度将一个张量拆分成多个子张量。可以通过指定拆分的维度和每个子张量的大小来进行拆分。例如,可以使用以下代码将一个张量按照第二维度拆分成两个子张量:
```python
import torch
x = torch.randn(2, 4, 6)
x1, x2 = torch.split(x, 2, dim=1)
print(x1)
print(x2)
```
输出结果如下所示:
```
tensor([[[ 0.7661, -1.4321, -0.1708, -0.3841, -0.0089, -0.0578],
[ 0.0434, -0.5774, 0.1823, 0.2379, 0.8545, 0.6631]],
[[-1.4663, 0.5092, 0.5708, -0.2400, -0.3581, -1.2011],
[ 1.3259, -0.1898, 0.4644, 0.0454, 0.1176, -0.6469]]])
tensor([[[-0.5641, -0.6956, -0.6768, -0.4420, -0.3269, -0.9093],
[ 0.6612, 0.7997, -0.3426, -0.3708, -1.2086, -0.5329]],
[[-0.8263, -0.0202, 0.2203, -1.1411, -1.0526, -0.0677],
[-0.0630, 0.0035, -0.3452, 0.2190, 0.0981, 1.2634]]])
```
其中,x1 和 x2 分别是拆分后的两个子张量,每个子张量的大小为 (2, 2, 6)。
阅读全文