np.squeeze()和np.unsqueeze()
时间: 2023-12-03 19:05:01 浏览: 184
np.squeeze()和np.unsqueeze()都是numpy中的函数,用于改变数组的维度。
np.squeeze()函数可以将数组中维度为1的维度去掉,从而降低数组的维度。例如,对于一个形状为(1,5)的数组a,使用np.squeeze(a)函数可以将其转换为形状为(5,)的数组。
np.unsqueeze()函数则是在数组的指定位置插入一个维度为1的维度,从而增加数组的维度。例如,对于一个形状为(5,)的数组a,使用np.unsqueeze(a,0)函数可以在第0个位置插入一个维度为1的维度,从而将其转换为形状为(1,5)的数组。
相关问题
np.squeeze()和np.unsqueeze()介绍和举例
np.squeeze()和np.unsqueeze()都是用于改变数组的维度的函数,但它们的作用是相反的。
np.squeeze()函数可以将数组中维度为1的维度去掉,从而降低数组的维度。如果指定了axis参数,则只有该轴为1时才会被去掉,否则会将所有为1的维度都去掉。下面是一个例子:
举例:
import numpy as np
arr = np.array([[[[1,2,3],[4,5,6]]]])
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")
arr_1 = np.squeeze(arr, axis=0)
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")
arr_2 = np.squeeze(arr, axis=None)
print(type(arr_2), arr_2, arr_2.shape, sep='\n')
输出结果为:
<class 'numpy.ndarray'>
[[[[1 2 3]
[4 5 6]]]]
(1, 1, 2, 3)
==========================
<class 'numpy.ndarray'>
[[[1 2 3]
[4 5 6]]]
(1, 2, 3)
==========================
<class 'numpy.ndarray'>
[[1 2 3]
[4 5 6]]
(2, 3)
可以看到,原数组arr的维度为(1,1,2,3),使用np.squeeze()函数去掉第一个维度后,得到的新数组arr_1的维度为(1,2,3),再去掉所有为1的维度后,得到的新数组arr_2的维度为(2,3)。
相反,np.unsqueeze()函数可以在数组的指定轴上增加一个维度,从而增加数组的维度。下面是一个例子:
举例:
import numpy as np
arr = np.array([1,2,3])
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")
arr_1 = np.expand_dims(arr, axis=0)
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")
arr_2 = np.expand_dims(arr, axis=1)
print(type(arr_2), arr_2, arr_2.shape, sep='\n')
输出结果为:
<class 'numpy.ndarray'>
[1 2 3]
(3,)
==========================
<class 'numpy.ndarray'>
[[1 2 3]]
(1, 3)
==========================
<class 'numpy.ndarray'>
[
]
(3, 1)
可以看到,原数组arr的维度为(3,),使用np.expand_dims()函数在第一个维度上增加一个维度后,得到的新数组arr_1的维度为(1,3),在第二个维度上增加一个维度后,得到的新数组arr_2的维度为(3,1)。
def test(model, path): model.eval() mean_loss = [] for s in ['val', 'test']: image_root = '{}/{}'.format(path, s) gt_root = '{}/{}'.format(path, s) test_loader = test_dataset(image_root, gt_root) dice_bank = [] iou_bank = [] loss_bank = [] acc_bank = [] for i in range(test_loader.size): image, gt = test_loader.load_data() image = image.cuda() with torch.no_grad(): _, _, res = model(image) loss = structure_loss(res, torch.tensor(gt).unsqueeze(0).unsqueeze(0).cuda()) res = res.sigmoid().data.cpu().numpy().squeeze() gt = 1*(gt>0.5) res = 1*(res > 0.5) dice = mean_dice_np(gt, res) iou = mean_iou_np(gt, res) acc = np.sum(res == gt) / (res.shape[0]*res.shape[1]) loss_bank.append(loss.item()) dice_bank.append(dice) iou_bank.append(iou) acc_bank.append(acc) print('{} Loss: {:.4f}, Dice: {:.4f}, IoU: {:.4f}, Acc: {:.4f}'. format(s, np.mean(loss_bank), np.mean(dice_bank), np.mean(iou_bank), np.mean(acc_bank))) mean_loss.append(np.mean(loss_bank)) return mean_loss[0] 中文解释每一句
这段代码是用于测试一个图像分割模型的性能。首先,模型被设置为评估模式,接着对于验证集和测试集分别进行测试。数据集的路径被设置为image_root和gt_root。然后,使用test_dataset函数加载测试集的数据。接下来,对于每一个样本,图像和标注被加载并送入模型进行预测。计算预测结果和标注之间的结构损失并将结果转换为numpy数组。计算Dice, IoU和准确率等性能度量,并将其存储在相应的bank中。最后,输出测试集上的平均损失,Dice, IoU和准确率等性能度量。整个测试过程是在无梯度下进行的。函数的返回值是验证集上的平均损失。
阅读全文