torch.stack
Date: 2023-10-20 20:03:58
torch.stack is a PyTorch function that joins a sequence of tensors along a new dimension. All input tensors must have the same shape, and the returned tensor has one more dimension than each input.
For example, if we have two tensors of shape (3, 4) and we want to stack them along a new dimension (i.e., make a tensor of shape (2, 3, 4)), we can use torch.stack as follows:
```
import torch
a = torch.randn(3, 4)
b = torch.randn(3, 4)
c = torch.stack([a, b], dim=0)
print(c.shape) # Output: torch.Size([2, 3, 4])
```
Here, we first create two tensors `a` and `b` of shape (3, 4) with random values. We then pass them as a list to `torch.stack` with `dim=0`, which inserts the new dimension at position 0. The resulting tensor `c` has shape (2, 3, 4), where `c[0]` holds the values of `a` and `c[1]` holds the values of `b`.
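The `dim` argument can take other values as well. The following is a minimal sketch, reusing the same two (3, 4) tensors, of how the position of the new dimension changes with `dim`, and of how `torch.stack` relates to `torch.cat`. The variable names (`shapes`, `via_cat`, and so on) are only for illustration.

```python
import torch

a = torch.randn(3, 4)
b = torch.randn(3, 4)

# The new axis (of size 2, one entry per input tensor) is inserted at position `dim`.
# For 2-D inputs, valid values range from -3 to 2; dim=-1 means "last position".
shapes = {dim: tuple(torch.stack([a, b], dim=dim).shape) for dim in (0, 1, 2, -1)}
print(shapes)

# Stacking gives the same result as unsqueezing each input at `dim`
# and then concatenating along that dimension.
stacked = torch.stack([a, b], dim=1)
via_cat = torch.cat([a.unsqueeze(1), b.unsqueeze(1)], dim=1)
same = torch.equal(stacked, via_cat)
print(same)

# By contrast, torch.cat joins along an existing dimension and adds no new one.
cat_shape = tuple(torch.cat([a, b], dim=0).shape)
print(cat_shape)
```

In short, use `torch.stack` when the inputs should stay separate entries along a new axis (for example, building a batch from individual samples), and `torch.cat` when they should be merged along an axis that already exists.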