torch.split
时间: 2023-07-03 14:18:06 浏览: 145
`torch.split` 是 PyTorch 中的一个函数,用于将一个张量沿着指定的维度分割成多个子张量。
函数定义:
```python
torch.split(tensor, split_size_or_sections, dim=0)
```
参数说明:
- `tensor`:需要分割的张量。
- `split_size_or_sections`:可以是一个整数,表示每个子张量的大小;也可以是一个列表,表示以指定的大小分割张量。
- `dim`:指定分割的维度,默认为 0。
返回值:
返回一个元组,包含分割后的子张量。
示例:
```python
import torch
x = torch.randn(6, 3)
print(x)
# 沿着第一维分割成两个子张量
x1, x2 = torch.split(x, 3, dim=0)
print(x1)
print(x2)
# 沿着第二维分割成三个子张量
x1, x2, x3 = torch.split(x, [1, 1, 1], dim=1)
print(x1)
print(x2)
print(x3)
```
输出:
```
tensor([[ 1.0139, -0.7239, -0.9594],
[-0.3797, -0.0799, -0.1045],
[ 0.6062, -0.6313, 0.2924],
[-1.2285, -0.7122, -1.1791],
[-0.5107, -1.1765, 0.8406],
[-0.6899, -0.9661, 1.1744]])
tensor([[ 1.0139, -0.7239, -0.9594],
[-0.3797, -0.0799, -0.1045],
[ 0.6062, -0.6313, 0.2924]])
tensor([[-1.2285, -0.7122, -1.1791],
[-0.5107, -1.1765, 0.8406],
[-0.6899, -0.9661, 1.1744]])
tensor([[ 1.0139],
[-0.3797],
[ 0.6062],
[-1.2285],
[-0.5107],
[-0.6899]])
tensor([[-0.7239],
[-0.0799],
[-0.6313],
[-0.7122],
[-1.1765],
[-0.9661]])
tensor([[-0.9594],
[-0.1045],
[ 0.2924],
[-1.1791],
[ 0.8406],
[ 1.1744]])
```
阅读全文