pytorch matmul
时间: 2023-07-14 07:08:54 浏览: 58
`torch.matmul` 是 PyTorch 中的一个函数,用于执行矩阵乘法操作。它接受两个张量作为输入,并返回它们的矩阵乘积。
`torch.matmul` 可以用于执行多种矩阵乘法操作,具体取决于输入张量的维度。以下是一些常见的用法示例:
1. 两个二维张量的矩阵乘法:
```python
import torch
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[5, 6], [7, 8]])
result = torch.matmul(x, y)
print(result)
```
输出:
```
tensor([[19, 22],
[43, 50]])
```
2. 一个二维张量和一个一维张量之间的矩阵乘法:
```python
import torch
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([5, 6])
result = torch.matmul(x, y)
print(result)
```
输出:
```
tensor([17, 39])
```
3. 执行批量矩阵乘法:
```python
import torch
x = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
y = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
result = torch.matmul(x, y)
print(result)
```
输出:
```
tensor([[[31, 34],
[71, 78]],
[[83, 90],
[119, 130]]])
```
需要注意的是,`torch.matmul` 对于张量的形状有一些要求,例如二维张量的最后一个维度和一维张量的第一个维度必须匹配,才能进行矩阵乘法。另外,批量矩阵乘法时需要保持张量的维度匹配。如果输入张量维度不满足要求,将会抛出错误。