如何取一个形状为(batch_size=2,seq_len=128,d_model=768)的张量的第1个维度seq_len的前64的张量?
时间: 2023-12-03 11:43:47 浏览: 95
pytorch dataloader 取batch_size时候出现bug的解决方式
可以使用以下代码来取出第1个维度seq_len的前64的张量:
```python
import torch
# 创建一个形状为(batch_size=2, seq_len=128, d_model=768)的张量
tensor = torch.randn(2, 128, 768)
# 取出第1个维度seq_len的前64的张量
output = tensor[:, :64, :]
print(output.shape)
```
这里使用了Python中的切片操作,`[:, :64, :]`表示取出所有batch_size和d_model的维度,但是只取seq_len的前64个元素。最后打印output的形状,应该为(batch_size=2, seq_len=64, d_model=768)。
阅读全文