np.save()函数详解
时间: 2023-05-11 13:02:40 浏览: 91
np.save()函数是numpy库中用于将数组保存到磁盘的函数。它的语法为:
np.save(file, arr, allow_pickle=True, fix_imports=True)
其中,file参数是保存文件的路径和名称,arr参数是要保存的数组,allow_pickle和fix_imports是可选参数,分别表示是否允许使用pickle序列化和是否修复导入问题。
使用np.save()函数可以将数组保存为.npy格式的文件,可以使用np.load()函数加载这些文件。这个函数在数据分析和机器学习中非常常用。
相关问题
np.save函数详解
`np.save` 函数是 NumPy 库中用于将数组数据保存到文件的函数,它的语法为:
```python
np.save(file, arr, allow_pickle=True, fix_imports=True)
```
其中:
- `file`:要保存到的文件名或 Python 文件对象。
- `arr`:要保存的数组数据。
- `allow_pickle`:可选参数,用于指定是否允许保存 pickle 序列化的对象,默认为 True。
- `fix_imports`:可选参数,用于指定是否尝试将 Python 2 中的 pickle 对象转换为 Python 3 中的对象,默认为 True。
`np.save` 函数将数组数据保存为 `.npy` 格式的文件,该文件包含 NumPy 数组的数据和元数据。
下面是一个示例,将数组 `a` 保存到文件 `array.npy` 中:
```python
import numpy as np
a = np.array([1, 2, 3, 4, 5])
np.save('array.npy', a)
```
如果要读取保存的数组数据,可以使用 `np.load` 函数,它的语法为:
```python
np.load(file, mmap_mode=None, allow_pickle=True, fix_imports=True, encoding='ASCII')
```
其中:
- `file`:要读取的 `.npy` 文件名或 Python 文件对象。
- `mmap_mode`:可选参数,用于指定内存映射模式,默认为 None。
- `allow_pickle`:可选参数,用于指定是否允许读取 pickle 序列化的对象,默认为 True。
- `fix_imports`:可选参数,用于指定是否尝试将 Python 2 中的 pickle 对象转换为 Python 3 中的对象,默认为 True。
- `encoding`:可选参数,用于指定读取文本数据时使用的字符编码,默认为 'ASCII'。
下面是一个示例,读取文件 `array.npy` 中保存的数组数据:
```python
import numpy as np
a = np.load('array.npy')
print(a)
```
输出结果为:
```
[1 2 3 4 5]
```
除了 `.npy` 格式的文件,NumPy 还支持保存为 `.npz` 格式的文件,该文件包含多个 NumPy 数组的数据和元数据。可以使用 `np.savez` 函数将多个数组保存到一个 `.npz` 文件中,它的语法为:
```python
np.savez(file, *args, **kwds)
```
其中:
- `file`:要保存到的文件名或 Python 文件对象。
- `*args`:要保存的数组数据。
- `**kwds`:可选参数,用于指定数组数据的名称。
下面是一个示例,将数组 `a` 和 `b` 保存到文件 `arrays.npz` 中:
```python
import numpy as np
a = np.array([1, 2, 3, 4, 5])
b = np.array([6, 7, 8, 9, 10])
np.savez('arrays.npz', a=a, b=b)
```
如果要读取保存的多个数组数据,可以使用 `np.load` 函数,它的语法为:
```python
np.load(file, mmap_mode=None, allow_pickle=True, fix_imports=True, encoding='ASCII')
```
其中:
- `file`:要读取的 `.npz` 文件名或 Python 文件对象。
- `mmap_mode`:可选参数,用于指定内存映射模式,默认为 None。
- `allow_pickle`:可选参数,用于指定是否允许读取 pickle 序列化的对象,默认为 True。
- `fix_imports`:可选参数,用于指定是否尝试将 Python 2 中的 pickle 对象转换为 Python 3 中的对象,默认为 True。
- `encoding`:可选参数,用于指定读取文本数据时使用的字符编码,默认为 'ASCII'。
下面是一个示例,读取文件 `arrays.npz` 中保存的数组数据:
```python
import numpy as np
data = np.load('arrays.npz')
a = data['a']
b = data['b']
print(a)
print(b)
```
输出结果为:
```
[1 2 3 4 5]
[ 6 7 8 9 10]
```
总之,`np.save` 函数可以将单个数组保存为 `.npy` 格式的文件,`np.load` 函数可以读取 `.npy` 文件中的单个数组数据;`np.savez` 函数可以将多个数组保存为 `.npz` 格式的文件,`np.load` 函数可以读取 `.npz` 文件中的多个数组数据。这些函数的使用非常简单,但是在实际应用中非常有用。
loftr代码详解demo.py详解
demo.py 是 LoFTR 算法的主要测试文件,用于演示 LoFTR 算法在图像匹配和三维重建任务上的效果,下面我将对其中的关键代码进行详细解释。
首先是导入必要的包和模块:
```python
import os
import sys
import time
import argparse
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from utils.tools import (
Timer,
str2bool,
load_image,
plot_keypoints,
plot_matches,
plot_reprojection,
plot_trajectory,
save_trajectory,
)
from utils.evaluation import compute_repeatability, compute_precision_recall
from models.loftr import LoFTR, default_cfg
```
其中 `os`、`sys`、`time`、`argparse`、`numpy`、`cv2`、`torch`、`matplotlib.pyplot` 都是 Python 常用的标准库或第三方库,此处不再赘述。`plot_keypoints`、`plot_matches`、`plot_reprojection` 和 `plot_trajectory` 是自定义的用于可视化结果的函数。`compute_repeatability` 和 `compute_precision_recall` 是用于评估匹配和重建结果的函数,详见 `evaluation.py` 文件。`LoFTR` 是 LoFTR 模型的主要类,`default_cfg` 是 LoFTR 模型的默认配置。
然后是解析命令行参数:
```python
parser = argparse.ArgumentParser(
description="LoFTR: Detector-Suppressor for 3D Local Features",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("img0_path", type=str, help="path to the reference image")
parser.add_argument("img1_path", type=str, help="path to the query image")
parser.add_argument("--weights", type=str, default=default_cfg["weights"], help="path to the pretrained model weights")
parser.add_argument("--cfg", type=str, default=default_cfg["cfg"], help="path to the config file")
parser.add_argument("--matcher", type=str, default=default_cfg["matcher"], help="feature matcher")
parser.add_argument("--suppression", type=str, default=default_cfg["suppression"], help="feature suppressor")
parser.add_argument("--top_k", type=int, default=default_cfg["top_k"], help="keep top_k keypoints")
parser.add_argument("--max_length", type=int, default=default_cfg["max_length"], help="maximum sequence length")
parser.add_argument("--resize", type=str2bool, nargs="?", const=True, default=True, help="resize input images")
parser.add_argument("--show", type=str2bool, nargs="?", const=True, default=False, help="show results")
parser.add_argument(
"--eval", type=str2bool, nargs="?", const=True, default=False, help="evaluate repeatability and matching performance"
)
parser.add_argument("--output_dir", type=str, default="outputs", help="output directory")
args = parser.parse_args()
```
其中 `img0_path` 和 `img1_path` 分别表示参考图像和查询图像的路径。`weights`、`cfg`、`matcher`、`suppression`、`top_k`、`max_length` 分别表示 LoFTR 模型的权重文件、配置文件、特征匹配器、特征抑制器、保留的关键点数量、序列的最大长度。`resize`、`show`、`eval`、`output_dir` 分别表示是否对输入图像进行缩放、是否显示结果、是否评估性能、输出结果的目录。
接下来是读取图像并将其转换为张量:
```python
img0 = load_image(args.img0_path, resize=args.resize)
img1 = load_image(args.img1_path, resize=args.resize)
```
其中 `load_image` 函数用于加载图像,将其转换为 BGR 格式,并将其缩放到固定大小,返回值是一个 `torch.Tensor` 对象。
然后是加载 LoFTR 模型:
```python
loftr = LoFTR(args.cfg, args.weights, args.matcher, args.suppression, args.top_k, args.max_length)
```
这里调用了 `LoFTR` 类,传入参数表示加载指定的配置文件、权重文件、特征匹配器、特征抑制器、保留的关键点数量和序列的最大长度。该类主要包含以下方法:
- `__init__(self, cfg, weights, matcher, suppression, top_k, max_length)`:初始化函数,加载模型权重和配置文件。
- `extract_features(self, img)`:提取图像的局部特征,返回值是一个元组 `(keypoints, descriptors)`,其中 `keypoints` 是关键点坐标,`descriptors` 是关键点特征描述子。
- `match(self, ref_feats, query_feats)`:在参考图像和查询图像的局部特征之间进行匹配,返回值是一个元组 `(matches_ref, matches_query)`,其中 `matches_ref` 是参考图像中的匹配点坐标,`matches_query` 是查询图像中的匹配点坐标。
- `reconstruct(self, ref_img, ref_feats, query_img, query_feats, matches)`:利用 LoFTR 算法进行三维重建,返回值是一个元组 `(R, t, pts_ref, pts_query, pts_3d)`,其中 `R` 和 `t` 是参考图像到查询图像的旋转矩阵和平移向量,`pts_ref` 和 `pts_query` 是参考图像和查询图像中的匹配点坐标,`pts_3d` 是三维重建得到的点云坐标。
接下来是提取图像的局部特征:
```python
timer = Timer()
timer.start()
kpts0, desc0 = loftr.extract_features(img0)
kpts1, desc1 = loftr.extract_features(img1)
timer.stop('extract features')
```
这里调用了 `extract_features` 方法,传入参数是加载的 LoFTR 模型和图像张量,返回值是两个元组 `(keypoints, descriptors)`,分别表示两幅图像的关键点坐标和特征描述子。这里还使用了 `Timer` 类来统计函数运行时间,方便后面的性能评估。
然后是在两幅图像之间进行特征匹配:
```python
timer.start()
matches0, matches1 = loftr.match((kpts0, desc0), (kpts1, desc1))
timer.stop('match features')
```
这里调用了 `match` 方法,传入参数是两个元组 `(keypoints, descriptors)`,分别表示参考图像和查询图像的关键点坐标和特征描述子。返回值也是两个元组 `(matches_ref, matches_query)`,分别表示参考图像和查询图像中的匹配点坐标。这里同样使用了 `Timer` 类来统计函数运行时间。
接下来是在两幅图像之间进行三维重建:
```python
timer.start()
R, t, pts0, pts1, pts3d = loftr.reconstruct(img0, (kpts0, desc0), img1, (kpts1, desc1), (matches0, matches1))
timer.stop('reconstruct 3D')
```
这里调用了 `reconstruct` 方法,传入参数是参考图像、参考图像的局部特征、查询图像、查询图像的局部特征和两幅图像之间的匹配点坐标。返回值是一个元组 `(R, t, pts0, pts1, pts3d)`,分别表示参考图像到查询图像的旋转矩阵和平移向量,两幅图像中的匹配点坐标和三维重建得到的点云坐标。同样使用了 `Timer` 类来统计函数运行时间。
最后是对结果进行可视化和保存:
```python
if args.show or args.eval:
plot_keypoints(img0, kpts0, title="Image 0 keypoints")
plot_keypoints(img1, kpts1, title="Image 1 keypoints")
plot_matches(img0, img1, matches0, matches1, title="Matches", savepath=os.path.join(args.output_dir, "matches.png"))
plot_reprojection(pts0, pts1, pts3d, R, t, title="Reprojection", savepath=os.path.join(args.output_dir, "reprojection.png"))
plot_trajectory(pts3d, title="Trajectory", savepath=os.path.join(args.output_dir, "trajectory.png"))
save_trajectory(os.path.join(args.output_dir, "trajectory.txt"), pts3d)
if args.eval:
repeatability = compute_repeatability(kpts0, kpts1, matches0, matches1)
precision, recall = compute_precision_recall(kpts0, kpts1, matches0, matches1, pts3d)
print(f"Repeatability: {repeatability:.4f}")
print(f"Precision@{len(matches0)}: {precision:.4f}")
print(f"Recall@{len(matches0)}: {recall:.4f}")
```
这里根据命令行参数设置是否显示和保存可视化结果,以及是否评估匹配和重建性能。如果需要显示和保存结果,则调用 `plot_keypoints`、`plot_matches`、`plot_reprojection` 和 `plot_trajectory` 函数生成相应的图像,保存到指定目录。如果需要评估性能,则调用 `compute_repeatability` 和 `compute_precision_recall` 函数计算重复性、精度和召回率,输出结果。最后还调用了 `save_trajectory` 函数将重建得到的点云坐标保存到文件中。
阅读全文