torch.stack.T.repeat
时间: 2023-11-27 21:43:24 浏览: 105
torch.stack是PyTorch中的一个函数,用于将多个张量按照指定维度进行堆叠,生成一个新的张量。它的用法如下:
```python
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.tensor([7, 8, 9])
stacked = torch.stack([x, y, z], dim=0)
print(stacked)
```
输出如下:
```
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
```
这里我们将三个一维张量按照第0个维度进行堆叠,生成了一个三行三列的二维张量。
T是PyTorch中的一个张量操作,用于将张量进行转置。它的用法如下:
```python
import torch
x = torch.tensor([[1, 2], [3, 4], [5, 6]])
transposed = x.T
print(transposed)
```
输出如下:
```
tensor([[1, 3, 5],
[2, 4, 6]])
```
这里我们将一个三行两列的二维张量进行转置,生成了一个两行三列的二维张量。
repeat是PyTorch中的一个张量操作,用于将张量按照指定维度进行复制,生成一个新的张量。它的用法如下:
```python
import torch
x = torch.tensor([1, 2, 3])
repeated = x.repeat(3, 1) # 在维度0上复制3次,在维度1上复制1次
print(repeated)
```
输出如下:
```
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
```
这里我们将一个一维张量按照维度0进行复制三次,按照维度1进行复制一次,生成了一个三行三列的二维张量。
阅读全文