paddle.split用法
时间: 2024-01-05 22:04:55 浏览: 136
`paddle.split`函数是PaddlePaddle中的一个操作,用于将张量沿指定维度分割成多个子张量。它的语法格式如下:
```python
paddle.split(input, num_or_sections, axis=0)
```
参数说明:
- `input`:待分割的张量,可以是一个变量或常量。
- `num_or_sections`:分割的段数或每一段的长度。当为整数时,表示分割成的段数;当为列表时,表示每一段的长度。
- `axis`:沿着哪个维度进行分割,默认为0,即沿着第一个维度分割。
使用示例:
```python
import paddle
x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(x)
# 沿着第一个维度分割成3个子张量
y1, y2, y3 = paddle.split(x, num_or_sections=3, axis=0)
print(y1)
print(y2)
print(y3)
# 沿着第二个维度分割成3个子张量
z1, z2, z3 = paddle.split(x, num_or_sections=3, axis=1)
print(z1)
print(z2)
print(z3)
```
输出结果:
```
Tensor(shape=[3, 3], dtype=int64, place=CPUPlace, stop_gradient=True,
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
Tensor(shape=[1, 3], dtype=int64, place=CPUPlace, stop_gradient=True,
[[1, 2, 3]])
Tensor(shape=[1, 3], dtype=int64, place=CPUPlace, stop_gradient=True,
[[4, 5, 6]])
Tensor(shape=[1, 3], dtype=int64, place=CPUPlace, stop_gradient=True,
[[7, 8, 9]])
Tensor(shape=[3, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
[[1],
[4],
[7]])
Tensor(shape=[3, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
[[2],
[5],
[8]])
Tensor(shape=[3, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
[[3],
[6],
[9]])
```
可以看出,将一个 $3 \times 3$ 的张量沿着第一个维度分割成3个子张量时,每个子张量的形状为 $1 \times 3$;沿着第二个维度分割成3个子张量时,每个子张量的形状为 $3 \times 1$。
阅读全文