torch 矩阵乘法
时间: 2023-09-21 22:09:41 浏览: 97
在 PyTorch 中,可以使用 `torch.matmul()` 或 `torch.mm()` 函数进行矩阵乘法操作。
`torch.matmul()` 函数可以用于执行两个张量之间的矩阵乘法,支持高维张量的操作。而 `torch.mm()` 函数则用于执行两个二维张量之间的矩阵乘法。
以下是使用这两个函数的示例代码:
```python
import torch
# 创建两个张量
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
# 使用 torch.matmul() 进行矩阵乘法
C = torch.matmul(A, B)
# 使用 torch.mm() 进行矩阵乘法(仅限于二维张量)
D = torch.mm(A, B)
print("torch.matmul():")
print(C)
print("\ntorch.mm():")
print(D)
```
输出结果为:
```
torch.matmul():
tensor([[19, 22],
[43, 50]])
torch.mm():
tensor([[19, 22],
[43, 50]])
```
在上述示例中,我们创建了两个2x2的张量 `A` 和 `B`,然后使用 `torch.matmul()` 和 `torch.mm()` 对它们进行矩阵乘法操作,得到结果张量 `C` 和 `D`。注意,`torch.mm()` 只能用于二维张量的矩阵乘法,而 `torch.matmul()` 可以处理更高维度的张量。
阅读全文