torch中的拆分函数
时间: 2023-09-08 09:16:48 浏览: 215
PyTorch中Tensor的拼接与拆分的实现
5星 · 资源好评率100%
在torch中,可以使用split函数实现张量的拆分。该函数的语法为:
```python
torch.split(tensor, split_size_or_sections, dim=0)
```
其中,`tensor`是待拆分的张量,`split_size_or_sections`表示拆分的大小或者拆分的份数,`dim`表示拆分的维度(默认为0)。如果`split_size_or_sections`是一个整数,则表示在指定维度上平均拆分成几份;如果`split_size_or_sections`是一个列表或元组,则表示按照指定的大小进行拆分。
下面是一个示例:
```python
import torch
# 创建一个形状为[4, 8]的张量
x = torch.randn(4, 8)
# 在第1维上拆分成2份
x1, x2 = torch.split(x, split_size_or_sections=2, dim=1)
print(x1.shape) # 输出torch.Size([4, 2])
print(x2.shape) # 输出torch.Size([4, 6])
```
在这个示例中,我们将一个形状为[4, 8]的张量在第1维上拆分成2份,得到了两个形状分别为[4, 2]和[4, 6]的张量。
阅读全文