pytorch的BROADCASTING SEMANTICS介绍
时间: 2023-10-16 18:31:24 浏览: 143
pytorch语义分割
在 PyTorch 中,broadcasting 是一种在不同形状的张量之间执行数学运算的方式。它允许我们在不进行显式复制数据的情况下对张量进行操作,从而减少了内存使用和计算时间。
PyTorch 中的 broadcasting 遵循以下规则:
1. 如果两个张量的维度数不同,则将较小的张量的形状通过在前面添加大小为 1 的维度来扩展,直到两个张量的维度数相同。
2. 对于每个维度,如果两个张量的大小相等,或其中一个张量的大小为 1,则可以执行广播。
3. 如果两个张量在某个维度上的大小都不为 1,且大小不相等,则无法执行广播,会报错。
下面是一些示例:
```
import torch
# 例子1
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = a + b
print(c) # tensor([5, 7, 9])
# 例子2
a = torch.tensor([1, 2, 3])
b = torch.tensor([[4], [5], [6]])
c = a + b
print(c) # tensor([[5, 6, 7],
# [6, 7, 8],
# [7, 8, 9]])
# 例子3
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = a + b # 报错:RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0
```
在例子1中,a 和 b 都是形状为 (3,) 的张量,直接进行加法运算时可以执行广播。
在例子2中,a 是形状为 (3,) 的张量,b 是形状为 (3,1) 的张量,通过添加一个大小为 1 的维度来扩展 a,从而可以执行广播。
在例子3中,a 和 b 在第 0 维的大小不同,无法执行广播,因此会报错。
阅读全文