torch.einsum
Date: 2023-10-31 18:58:37
torch.einsum is a PyTorch function that evaluates tensor operations written in Einstein summation notation. A single equation string can express matrix multiplication, transposition, traces, reductions, and similar operations, so you do not have to write explicit loops or indexing code. The function takes an equation string that describes the operation and one or more input tensors, and returns the resulting tensor.
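For reference, the classical convention (standard notation, not specific to PyTorch) is that an index repeated across factors is implicitly summed, so the summation sign is dropped. A dot product and a matrix product are written:

$$
c = a_i b_i \equiv \sum_i a_i b_i,
\qquad
C_{ik} = A_{ij} B_{jk} \equiv \sum_j A_{ij} B_{jk}
$$

When `->` is present, torch.einsum uses a slightly more explicit rule: an index is summed if it does not appear in the output subscripts. For example, `'i,i->i'` is an elementwise product rather than a dot product.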
The syntax for torch.einsum is as follows:
```python
torch.einsum(equation, *operands)
```
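where `equation` is a string that specifies the operation in Einstein summation notation and `*operands` are the input tensors. The equation gives one group of subscript letters per operand, separated by commas, optionally followed by `->` and the output subscripts. Any index that appears in the inputs but not in the output is summed over.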
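The summation rule can be checked with a few one- and two-operand equations. The following is a sketch that uses small, hand-checkable tensors. The variable names are illustrative, and each comment names the ordinary PyTorch call that the equation should match.

```python
import torch

# Small tensors with fixed values so results are easy to check by hand
a = torch.arange(6.0).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
sq = torch.arange(9.0).reshape(3, 3)  # square matrix, diagonal 0, 4, 8
u = torch.tensor([1.0, 2.0, 3.0])
v = torch.tensor([4.0, 5.0, 6.0])

# Output subscripts reordered: transpose, like a.T
transposed = torch.einsum('ij->ji', a)
# j dropped from the output, so it is summed: row sums 0+1+2=3 and 3+4+5=12
row_sums = torch.einsum('ij->i', a)
# No output subscripts: every index is summed, giving the scalar total 15
total = torch.einsum('ij->', a)
# Repeated index within one operand selects the diagonal; summing it gives the trace 0+4+8=12
trace = torch.einsum('ii->', sq)
# Shared index dropped from the output: dot product 1*4 + 2*5 + 3*6 = 32
dot = torch.einsum('i,i->', u, v)
# Distinct indices all kept: outer product, like torch.outer(u, v)
outer = torch.einsum('i,j->ij', u, v)
# Implicit mode (no '->'): output is the indices appearing once, in alphabetical order ('ik')
implicit = torch.einsum('ij,jk', a, sq)

print(row_sums, total, trace, dot)
```

In the last equation, `j` appears twice, so implicit mode sums over it and keeps `i` and `k`. This makes `'ij,jk'` equivalent to `'ij,jk->ik'`.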
For example, the following code performs a matrix multiplication using torch.einsum:
```python
import torch
# create input tensors
x = torch.randn(2, 3)
y = torch.randn(3, 4)
# perform matrix multiplication using torch.einsum
z = torch.einsum('ij, jk -> ik', x, y)
print(z)
```
This code creates two input tensors, `x` and `y`, with shapes `(2, 3)` and `(3, 4)`, respectively. The equation `'ij, jk -> ik'` labels the dimensions of `x` as `i, j` and those of `y` as `j, k`. Because `j` is shared by both inputs and does not appear in the output `ik`, it is summed over: each element `z[i, k]` is the sum over `j` of `x[i, j] * y[j, k]`, which is exactly matrix multiplication. The result `z` has shape `(2, 4)` and is printed to the console.
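To make the summation in `'ij, jk -> ik'` concrete, the sketch below compares the einsum result with an explicit triple loop and with `x @ y`. It then adds a batch index `b` that is kept in the output, which produces a batched matrix multiply, and compares that against `torch.bmm`. The seed, the batch size of 5, and the `atol=1e-6` tolerance are choices made only for this illustration. The tolerance absorbs float32 rounding differences between summation orders.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is reproducible
x = torch.randn(2, 3)
y = torch.randn(3, 4)

z = torch.einsum('ij, jk -> ik', x, y)

# Explicit-loop version of the same equation: j is summed because it is absent after '->'
z_loop = torch.zeros(2, 4)
for i in range(2):
    for k in range(4):
        for j in range(3):
            z_loop[i, k] += x[i, j] * y[j, k]

# A leading batch index b that is kept in the output gives a batched matmul
bx = torch.randn(5, 2, 3)
by = torch.randn(5, 3, 4)
bz = torch.einsum('bij,bjk->bik', bx, by)  # shape (5, 2, 4)

print(torch.allclose(z, x @ y, atol=1e-6))
print(torch.allclose(z, z_loop, atol=1e-6))
print(torch.allclose(bz, torch.bmm(bx, by), atol=1e-6))
```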