我想取一个形状为(batch_szie=2, seq_len=768, d_model=768)的张量前128个seq_len的张量,我该怎么写代码?
时间: 2024-05-12 09:15:47 浏览: 135
keras model.fit 解决validation_spilt=num 的问题
你可以使用切片操作符 `[..., :128, :]` 来取出前128个 `seq_len` 的张量,其中 `...` 表示所有的批次和其他轴。以下是示例代码:
```python
import torch
tensor = torch.randn(2, 768, 768)
slice_tensor = tensor[..., :128, :]
print(slice_tensor.shape) # 输出 torch.Size([2, 128, 768])
```
在这里,我们使用了 PyTorch 中的 `torch.randn()` 函数创建了一个形状为 `(2, 768, 768)` 的张量 `tensor`,然后使用切片操作符 `[..., :128, :]` 取出前 128 个 `seq_len` 的张量,并将其保存在 `slice_tensor` 中。最后打印出 `slice_tensor` 的形状,输出为 `torch.Size([2, 128, 768])`。
阅读全文