Torch.repeat_interleave用法
时间: 2024-06-09 20:07:48 浏览: 15
`torch.repeat_interleave` 函数的作用是将输入张量沿着指定的维度重复。可以使用该函数在张量中沿着指定的维度复制元素或者子张量。该函数的语法如下:
```python
torch.repeat_interleave(input, repeats, dim=None)
```
其中,参数说明如下:
- `input`:输入 Tensor。
- `repeats`:重复次数,可以为整数或者是一个 Tensor。如果 `repeats` 是一个整数,则沿着指定的维度重复该次数;如果 `repeats` 是一个 Tensor,则需要与 `input` 在指定维度上大小相同,表示沿着该维度上的每个元素需要重复的次数。
- `dim`:沿着该维度进行重复,默认为 `None`,表示将输入张量的所有元素进行重复。
示例代码如下:
```python
import torch
# 示例1:将一维向量中的每个元素重复 3 次
x = torch.tensor([1, 2, 3])
y = torch.repeat_interleave(x, 3)
print(y) # tensor([1, 1, 1, 2, 2, 2, 3, 3, 3])
# 示例2:将二维矩阵中的每个元素沿着行维度重复 2 次,沿着列维度重复 3 次
x = torch.tensor([[1, 2], [3, 4]])
y = torch.repeat_interleave(x, (2, 3), dim=1)
print(y)
# tensor([[1, 1, 1, 2, 2, 2],
# [3, 3, 3, 4, 4, 4],
# [1, 1, 1, 2, 2, 2],
# [3, 3, 3, 4, 4, 4]])
```
在示例1中,将一维向量 `[1, 2, 3]` 中的每个元素重复 3 次,得到了一维向量 `[1, 1, 1, 2, 2, 2, 3, 3, 3]`。
在示例2中,将二维矩阵 `[[1, 2], [3, 4]]` 中的每个元素沿着行维度重复 2 次,沿着列维度重复 3 次,得到了一个新的二维矩阵。其中,新矩阵的第一行是 `[1, 1, 1, 2, 2, 2]`,第二行是 `[3, 3, 3, 4, 4, 4]`,第三行和第四行与第一行和第二行相同。