python torch.squeeze
时间: 2023-12-21 20:04:52 浏览: 65
基于torch.nn的回归、二分类、多分类任务代码
`torch.squeeze()` 是 PyTorch 中的一个函数,用于去除张量(tensor)中维度为 1 的维度。它的作用是将形状为 (1, N) 或 (N, 1) 的张量转换为形状为 (N,) 的张量,即去除维度为 1 的维度。
例如,假设有一个形状为 (1, 3) 的张量 t:
```
tensor([[1, 2, 3]])
```
使用 `torch.squeeze(t)` 可以将其转换为形状为 (3,) 的张量:
```
tensor([1, 2, 3])
```
在调用 `torch.squeeze()` 函数时,你也可以指定要去除的维度。例如,如果张量 t 的形状为 (1, 3, 1):
```
tensor([[[1], [2], [3]]])
```
你可以使用 `torch.squeeze(t, dim=0)` 或 `t.squeeze(dim=0)` 来去除第一个维度,得到形状为 (3, 1) 的张量:
```
tensor([[1], [2], [3]])
```
总结起来,`torch.squeeze()` 函数的作用是压缩张量的维度,删除大小为 1 的维度。
阅读全文