写一段Pytorch计算批处理矩阵乘法的代码
时间: 2023-11-28 09:47:23 浏览: 76
矩阵乘法代码
下面是一个使用 PyTorch 实现的批处理矩阵乘法的例子:
```python
import torch
# 定义输入矩阵和权重矩阵
batch_size = 10
input_size = 5
hidden_size = 3
input = torch.randn(batch_size, input_size)
weight = torch.randn(input_size, hidden_size)
# 计算矩阵乘积
output = torch.mm(input, weight)
print(output)
```
在这个例子中,我们首先定义了输入矩阵 `input` 和权重矩阵 `weight`,它们的维度分别为 `(batch_size, input_size)` 和 `(input_size, hidden_size)`。接着,我们使用 `torch.mm()` 函数计算了它们的矩阵乘积,并将结果存储在变量 `output` 中。最后,我们打印出了 `output` 的值。
需要注意的是,这里的矩阵乘积是针对整个批次的输入进行计算的,因此输出矩阵的维度为 `(batch_size, hidden_size)`。如果想要对单个输入进行计算,可以将输入矩阵的维度设置为 `(1, input_size)`。
阅读全文