torch.mm函数什么意思
时间: 2023-05-26 13:02:03 浏览: 99
torch.mm函数是PyTorch中的矩阵乘法函数,用于计算两个二维张量之间的矩阵乘积,也就是常见的矩阵乘法运算。其参数为两个二维张量,参数分别为dim1×dim2和dim2×dim3的矩阵,函数返回值为dim1×dim3的矩阵。
例如,给定两个矩阵A和B,它们的维度分别为$m \times n$和$n \times p$,则它们的矩阵乘积为$C = A \times B$,其中矩阵C的维度为$m \times p$,即:
$$\begin{equation} \begin{pmatrix} a_{1,1} & a_{1,2} & \cdots & a_{1,n} \\ a_{2,1} & a_{2,2} & \cdots & a_{2,n} \\ \vdots & \vdots & \ddots & \vdots \\ a_{m,1} & a_{m,2} & \cdots & a_{m,n} \end{pmatrix} \times \begin{pmatrix} b_{1,1} & b_{1,2} & \cdots & b_{1,p} \\ b_{2,1} & b_{2,2} & \cdots & b_{2,p} \\ \vdots & \vdots & \ddots & \vdots \\ b_{n,1} & b_{n,2} & \cdots & b_{n,p} \end{pmatrix} = \begin{pmatrix} c_{1,1} & c_{1,2} & \cdots & c_{1,p} \\ c_{2,1} & c_{2,2} & \cdots & c_{2,p} \\ \vdots & \vdots & \ddots & \vdots \\ c_{m,1} & c_{m,2} & \cdots & c_{m,p} \end{pmatrix} \end{equation}$$
其中$c_{i,j}=\sum_{k=1}^{n}a_{i,k}b_{k,j}$,即为矩阵A的第i行与矩阵B的第j列的内积。在PyTorch中,可以使用torch.mm()函数进行矩阵乘积的计算。
阅读全文