np.squeeze()和np.unsqueeze()介绍和举例
时间: 2024-06-13 13:04:07 浏览: 158
np.squeeze()和np.unsqueeze()都是用于改变数组的维度的函数,但它们的作用是相反的。
np.squeeze()函数可以将数组中维度为1的维度去掉,从而降低数组的维度。如果指定了axis参数,则只有该轴为1时才会被去掉,否则会将所有为1的维度都去掉。下面是一个例子:
举例:
import numpy as np
arr = np.array([[[[1,2,3],[4,5,6]]]])
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")
arr_1 = np.squeeze(arr, axis=0)
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")
arr_2 = np.squeeze(arr, axis=None)
print(type(arr_2), arr_2, arr_2.shape, sep='\n')
输出结果为:
<class 'numpy.ndarray'>
[[[[1 2 3]
[4 5 6]]]]
(1, 1, 2, 3)
==========================
<class 'numpy.ndarray'>
[[[1 2 3]
[4 5 6]]]
(1, 2, 3)
==========================
<class 'numpy.ndarray'>
[[1 2 3]
[4 5 6]]
(2, 3)
可以看到,原数组arr的维度为(1,1,2,3),使用np.squeeze()函数去掉第一个维度后,得到的新数组arr_1的维度为(1,2,3),再去掉所有为1的维度后,得到的新数组arr_2的维度为(2,3)。
相反,np.unsqueeze()函数可以在数组的指定轴上增加一个维度,从而增加数组的维度。下面是一个例子:
举例:
import numpy as np
arr = np.array([1,2,3])
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")
arr_1 = np.expand_dims(arr, axis=0)
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")
arr_2 = np.expand_dims(arr, axis=1)
print(type(arr_2), arr_2, arr_2.shape, sep='\n')
输出结果为:
<class 'numpy.ndarray'>
[1 2 3]
(3,)
==========================
<class 'numpy.ndarray'>
[[1 2 3]]
(1, 3)
==========================
<class 'numpy.ndarray'>
[
]
(3, 1)
可以看到,原数组arr的维度为(3,),使用np.expand_dims()函数在第一个维度上增加一个维度后,得到的新数组arr_1的维度为(1,3),在第二个维度上增加一个维度后,得到的新数组arr_2的维度为(3,1)。
阅读全文