pytorch中的广播机制
时间: 2023-12-11 11:32:35 浏览: 73
在PyTorch中,广播机制是一种自动执行的机制,用于在不同形状的张量之间进行逐元素操作。当两个张量的形状不同时,PyTorch会自动将较小的张量广播到与较大的张量具有相同的形状,以便进行逐元素操作。以下是两个例子:
例子1:
```python
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状为(2, 3)
b = torch.tensor([10, 20, 30]) # 形状为(3,)
# 使用广播机制进行逐元素相加
c = a + b # 广播机制会自动将b扩展为(2, 3),使得a和b的形状相同
print(c)
```
输出结果为:
```
tensor([[11, 22, 33],
[14, 25, 36]])
```
例子2:
```python
import torch
a = torch.rand(2, 1, 1, 3) # 2 * 1 * 1 * 3
b = torch.rand(4, 2, 3) # 4 * 2 * 3
c = a + b # 广播机制会自动将a扩展为(2, 4, 2, 3),使得a和b的形状相同
print(c)
print(c.shape)
```
输出结果为:
```
tensor([[[[1.3085, 1.2017, 1.1239],
[1.3085, 1.2017, 1.1239]],
[[1.3085, 1.2017, 1.1239],
[1.3085, 1.2017, 1.1239]],
[[1.3085, 1.2017, 1.1239],
[1.3085, 1.2017, 1.1239]],
[[1.3085, 1.2017, 1.1239],
[1.3085, 1.2017, 1.1239]]],
[[[1.3085, 1.2017, 1.1239],
[1.3085, 1.2017, 1.1239]],
[[1.3085, 1.2017, 1.1239],
[1.3085, 1.2017, 1.1239]],
[[1.3085, 1.2017, 1.1239],
[1.3085, 1.2017, 1.1239]],
[[1.3085, 1.2017, 1.1239],
[1.3085, 1.2017, 1.1239]]]])
torch.Size([2, 4, 2, 3])
```
阅读全文