torch.split()函数的用法
时间: 2023-09-13 07:09:51 浏览: 222
`torch.split()`函数是PyTorch中的一个张量操作函数,用于将一个张量沿着指定维度进行分割。
它的语法如下:
```python
torch.split(tensor, split_size_or_sections, dim=0)
```
参数说明:
- `tensor`:要分割的张量。
- `split_size_or_sections`:可以是一个整数或一个整数列表。如果是整数,则表示每个子张量的大小;如果是整数列表,则表示每个子张量的大小。总大小必须与`tensor`在指定维度上的大小相匹配。
- `dim`:指定要沿着哪个维度进行分割。
函数返回一个由分割后的子张量组成的元组。
以下是一些示例:
```python
import torch
# 分割一个张量为两个相等大小的子张量
x = torch.tensor([1, 2, 3, 4, 5, 6])
sub_tensors = torch.split(x, 3)
print(sub_tensors) # (tensor([1, 2, 3]), tensor([4, 5, 6]))
# 按指定大小分割张量
x = torch.tensor([1, 2, 3, 4, 5, 6])
sub_tensors = torch.split(x, [2, 1, 3])
print(sub_tensors) # (tensor([1, 2]), tensor([3]), tensor([4, 5, 6]))
# 沿着指定维度分割张量
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
sub_tensors = torch.split(x, 2, dim=0)
print(sub_tensors) # (tensor([[1, 2, 3], [4, 5, 6]]), tensor([[7, 8, 9]]))
```
这些示例展示了`torch.split()`函数的基本用法,你可以根据自己的需求使用它来分割张量。
阅读全文