PyTorch 元素融合函数
时间: 2024-01-01 09:07:02 浏览: 157
PyTorch 提供了很多元素融合(Element-wise Fusion)函数,可以在多个 Tensor 之间进行逐元素计算并返回一个新的 Tensor。以下是一些常见的元素融合函数:
1. `torch.add(input, other, alpha=1)`: 将 input 和 other 逐元素相加,并返回一个新的 Tensor。
2. `torch.sub(input, other)`: 将 input 和 other 逐元素相减,并返回一个新的 Tensor。
3. `torch.mul(input, other)`: 将 input 和 other 逐元素相乘,并返回一个新的 Tensor。
4. `torch.div(input, other)`: 将 input 和 other 逐元素相除,并返回一个新的 Tensor。
5. `torch.pow(input, exponent)`: 将 input 中的每个元素取 exponent 次方,并返回一个新的 Tensor。
6. `torch.exp(input)`: 将 input 中的每个元素取指数,并返回一个新的 Tensor。
7. `torch.log(input)`: 将 input 中的每个元素取对数,并返回一个新的 Tensor。
8. `torch.abs(input)`: 将 input 中的每个元素取绝对值,并返回一个新的 Tensor。
9. `torch.sigmoid(input)`: 将 input 中的每个元素应用 sigmoid 函数,并返回一个新的 Tensor。
这些函数都支持广播机制(Broadcasting),即在计算时,会自动将形状不同的 Tensor 扩展到相同的形状,以便能够逐元素计算。例如,如果使用 `torch.add()` 函数计算两个形状不同的 Tensor,PyTorch 会自动将形状不同的 Tensor 扩展到相同的形状,然后逐元素相加。以下是一个示例代码:
```
import torch
a = torch.Tensor([[1, 2], [3, 4]])
b = torch.Tensor([1, 2])
c = torch.add(a, b)
print(c) # 输出 tensor([[2., 4.], [4., 6.]])
```
在这个示例中,使用 `torch.add()` 函数计算 `a` 和 `b` 时,PyTorch 会自动将 `b` 扩展为形状为 `(2, 2)` 的 Tensor,以便逐元素相加。
阅读全文