tf.expand_dims对应torch
时间: 2023-08-31 09:08:29 浏览: 154
tensorflow 利用expand_dims和squeeze扩展和压缩tensor维度方式
tf.expand_dims对应的是torch.unsqueeze函数。两个函数的作用都是在指定的维度上增加一个维度。举个例子,假设一个Tensor的shape是(3,4),那么使用tf.expand_dims或torch.unsqueeze在维度1上增加一个维度后,新的Tensor的shape都会变成(3,1,4)。具体用法如下:
TensorFlow中的tf.expand_dims用法:
```python
import tensorflow as tf
x = tf.constant([[1, 2], [3, 4]])
x = tf.expand_dims(x, 0) # 增加维度0
print(x.shape) # 输出 (1, 2, 2)
```
PyTorch中的torch.unsqueeze用法:
```python
import torch
x = torch.tensor([[1, 2], [3, 4]])
x = torch.unsqueeze(x, 0) # 增加维度0
print(x.shape) # 输出 torch.Size([1, 2, 2])
```
阅读全文