解释torch.bmm
时间: 2023-10-12 09:00:58 浏览: 178
浅谈pytorch torch.backends.cudnn设置作用
5星 · 资源好评率100%
torch.bmm 是 PyTorch 中的一个函数,用于对两个三维张量进行批量矩阵乘法。
假设第一个张量 A 的形状为 (batch_size, n, m),第二个张量 B 的形状为 (batch_size, m, p),则 torch.bmm 函数的输出是一个形状为 (batch_size, n, p) 的三维张量 C,其中 C[i,:,:] = torch.mm(A[i,:,:], B[i,:,:])。
换句话说,torch.bmm 函数对于第一个维度上的每个样本,在第二个维度上进行矩阵乘法,得到一个形状为 (n, p) 的矩阵,最终将这些矩阵按照第一个维度拼接起来,得到形状为 (batch_size, n, p) 的三维张量。
阅读全文