python torch einsum
时间: 2024-12-31 08:42:08 浏览: 14
### Python 中 `torch.einsum` 函数的使用方法与案例
#### 基础概念
`torch.einsum` 是 PyTorch 提供的一个强大工具,用于执行爱因斯坦求和约定的操作。该操作允许用户指定输入张量之间的复杂运算模式,并能高效地计算这些模式下的乘积、转置以及各种维度上的聚合。
#### 使用语法
基本形式如下:
```python
torch.einsum(equation, *operands)
```
其中 `equation` 参数定义了各个参与运算的张量间的关系表达式;而 `*operands` 则是要进行运算的具体张量对象列表[^1]。
#### 实际例子说明
##### 示例一:矩阵相乘
对于两个二维张量 A 和 B 来说,如果想要实现标准意义上的矩阵乘法,则可以通过以下方式来调用 `einsum` 方法完成相同的功能:
```python
import torch
A = torch.randn((3, 4))
B = torch.randn((4, 5))
C = torch.einsum('ij,jk->ik', (A, B)) # 进行矩阵乘法操作
print(C.shape) # 输出应为 (3, 5),即 C=A*B 的结果形状
```
这里 `'ij,jk->ik'` 表达式的含义是从左至右依次表示第一个参数中的索引位置(i 和 j),第二个参数中的索引位置(j 和 k),最后箭头右边的部分指定了输出的结果应该具有的索引结构——也就是新的 i 和 k 组合而成的新轴线方向上形成最终返回值 C。
##### 示例二:批量向量内积
当处理多个样本的数据时,比如有 N 个长度相同的向量组成的批次数据 X 和 Y,可以直接利用 `einsum` 计算每一对对应位置处元素间的点积得到一个新的包含这 N 个数值的一维数组 Z:
```python
X = torch.rand([8, 7]) # 批次大小为8,特征数为7
Y = torch.rand_like(X)
Z = torch.einsum('ni,ni->n', [X, Y])
print(Z.size()) # 应打印出 torch.Size([8]), 即代表八个不同的内积结果
```
在这个例子中,由于我们希望保留批次数目不变的同时仅对各条记录内部做累积加权平均,因此选择了 ni 形式的标记作为共同作用域内的共享标签 n 并指向同一个物理空间坐标系里的不同实例个体。
##### 示例三:三维张量收缩
考虑更复杂的场景,假设有三个三维张量 P,Q,R 需要按照一定规则相互作用并压缩成更低维度的形式 S:
```python
P = torch.ones([2, 3, 4])
Q = torch.ones([3, 4, 5])
R = torch.ones([4])
S = torch.einsum('ijk,jkl,l->il', (P, Q, R))
print(S.shape) # 结果应该是 (2, 5), 对角线上每个元素等于原张量相应部分累加之和
```
上述代码片段展示了如何通过精心设计的字符串方程将多层嵌套结构简化为更加直观易懂的目标形态,在此过程中实现了高效的内存访问路径优化。
阅读全文