loftr代码详解demo.py详解
时间: 2023-07-03 21:04:32 浏览: 241
安卓应用-社交聊天-Lofter手机版 v6.18.1.zip
demo.py 是 LoFTR 算法的主要测试文件,用于演示 LoFTR 算法在图像匹配和三维重建任务上的效果,下面我将对其中的关键代码进行详细解释。
首先是导入必要的包和模块:
```python
import os
import sys
import time
import argparse
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from utils.tools import (
Timer,
str2bool,
load_image,
plot_keypoints,
plot_matches,
plot_reprojection,
plot_trajectory,
save_trajectory,
)
from utils.evaluation import compute_repeatability, compute_precision_recall
from models.loftr import LoFTR, default_cfg
```
其中 `os`、`sys`、`time`、`argparse`、`numpy`、`cv2`、`torch`、`matplotlib.pyplot` 都是 Python 常用的标准库或第三方库,此处不再赘述。`plot_keypoints`、`plot_matches`、`plot_reprojection` 和 `plot_trajectory` 是自定义的用于可视化结果的函数。`compute_repeatability` 和 `compute_precision_recall` 是用于评估匹配和重建结果的函数,详见 `evaluation.py` 文件。`LoFTR` 是 LoFTR 模型的主要类,`default_cfg` 是 LoFTR 模型的默认配置。
然后是解析命令行参数:
```python
parser = argparse.ArgumentParser(
description="LoFTR: Detector-Suppressor for 3D Local Features",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("img0_path", type=str, help="path to the reference image")
parser.add_argument("img1_path", type=str, help="path to the query image")
parser.add_argument("--weights", type=str, default=default_cfg["weights"], help="path to the pretrained model weights")
parser.add_argument("--cfg", type=str, default=default_cfg["cfg"], help="path to the config file")
parser.add_argument("--matcher", type=str, default=default_cfg["matcher"], help="feature matcher")
parser.add_argument("--suppression", type=str, default=default_cfg["suppression"], help="feature suppressor")
parser.add_argument("--top_k", type=int, default=default_cfg["top_k"], help="keep top_k keypoints")
parser.add_argument("--max_length", type=int, default=default_cfg["max_length"], help="maximum sequence length")
parser.add_argument("--resize", type=str2bool, nargs="?", const=True, default=True, help="resize input images")
parser.add_argument("--show", type=str2bool, nargs="?", const=True, default=False, help="show results")
parser.add_argument(
"--eval", type=str2bool, nargs="?", const=True, default=False, help="evaluate repeatability and matching performance"
)
parser.add_argument("--output_dir", type=str, default="outputs", help="output directory")
args = parser.parse_args()
```
其中 `img0_path` 和 `img1_path` 分别表示参考图像和查询图像的路径。`weights`、`cfg`、`matcher`、`suppression`、`top_k`、`max_length` 分别表示 LoFTR 模型的权重文件、配置文件、特征匹配器、特征抑制器、保留的关键点数量、序列的最大长度。`resize`、`show`、`eval`、`output_dir` 分别表示是否对输入图像进行缩放、是否显示结果、是否评估性能、输出结果的目录。
接下来是读取图像并将其转换为张量:
```python
img0 = load_image(args.img0_path, resize=args.resize)
img1 = load_image(args.img1_path, resize=args.resize)
```
其中 `load_image` 函数用于加载图像,将其转换为 BGR 格式,并将其缩放到固定大小,返回值是一个 `torch.Tensor` 对象。
然后是加载 LoFTR 模型:
```python
loftr = LoFTR(args.cfg, args.weights, args.matcher, args.suppression, args.top_k, args.max_length)
```
这里调用了 `LoFTR` 类,传入参数表示加载指定的配置文件、权重文件、特征匹配器、特征抑制器、保留的关键点数量和序列的最大长度。该类主要包含以下方法:
- `__init__(self, cfg, weights, matcher, suppression, top_k, max_length)`:初始化函数,加载模型权重和配置文件。
- `extract_features(self, img)`:提取图像的局部特征,返回值是一个元组 `(keypoints, descriptors)`,其中 `keypoints` 是关键点坐标,`descriptors` 是关键点特征描述子。
- `match(self, ref_feats, query_feats)`:在参考图像和查询图像的局部特征之间进行匹配,返回值是一个元组 `(matches_ref, matches_query)`,其中 `matches_ref` 是参考图像中的匹配点坐标,`matches_query` 是查询图像中的匹配点坐标。
- `reconstruct(self, ref_img, ref_feats, query_img, query_feats, matches)`:利用 LoFTR 算法进行三维重建,返回值是一个元组 `(R, t, pts_ref, pts_query, pts_3d)`,其中 `R` 和 `t` 是参考图像到查询图像的旋转矩阵和平移向量,`pts_ref` 和 `pts_query` 是参考图像和查询图像中的匹配点坐标,`pts_3d` 是三维重建得到的点云坐标。
接下来是提取图像的局部特征:
```python
timer = Timer()
timer.start()
kpts0, desc0 = loftr.extract_features(img0)
kpts1, desc1 = loftr.extract_features(img1)
timer.stop('extract features')
```
这里调用了 `extract_features` 方法,传入参数是加载的 LoFTR 模型和图像张量,返回值是两个元组 `(keypoints, descriptors)`,分别表示两幅图像的关键点坐标和特征描述子。这里还使用了 `Timer` 类来统计函数运行时间,方便后面的性能评估。
然后是在两幅图像之间进行特征匹配:
```python
timer.start()
matches0, matches1 = loftr.match((kpts0, desc0), (kpts1, desc1))
timer.stop('match features')
```
这里调用了 `match` 方法,传入参数是两个元组 `(keypoints, descriptors)`,分别表示参考图像和查询图像的关键点坐标和特征描述子。返回值也是两个元组 `(matches_ref, matches_query)`,分别表示参考图像和查询图像中的匹配点坐标。这里同样使用了 `Timer` 类来统计函数运行时间。
接下来是在两幅图像之间进行三维重建:
```python
timer.start()
R, t, pts0, pts1, pts3d = loftr.reconstruct(img0, (kpts0, desc0), img1, (kpts1, desc1), (matches0, matches1))
timer.stop('reconstruct 3D')
```
这里调用了 `reconstruct` 方法,传入参数是参考图像、参考图像的局部特征、查询图像、查询图像的局部特征和两幅图像之间的匹配点坐标。返回值是一个元组 `(R, t, pts0, pts1, pts3d)`,分别表示参考图像到查询图像的旋转矩阵和平移向量,两幅图像中的匹配点坐标和三维重建得到的点云坐标。同样使用了 `Timer` 类来统计函数运行时间。
最后是对结果进行可视化和保存:
```python
if args.show or args.eval:
plot_keypoints(img0, kpts0, title="Image 0 keypoints")
plot_keypoints(img1, kpts1, title="Image 1 keypoints")
plot_matches(img0, img1, matches0, matches1, title="Matches", savepath=os.path.join(args.output_dir, "matches.png"))
plot_reprojection(pts0, pts1, pts3d, R, t, title="Reprojection", savepath=os.path.join(args.output_dir, "reprojection.png"))
plot_trajectory(pts3d, title="Trajectory", savepath=os.path.join(args.output_dir, "trajectory.png"))
save_trajectory(os.path.join(args.output_dir, "trajectory.txt"), pts3d)
if args.eval:
repeatability = compute_repeatability(kpts0, kpts1, matches0, matches1)
precision, recall = compute_precision_recall(kpts0, kpts1, matches0, matches1, pts3d)
print(f"Repeatability: {repeatability:.4f}")
print(f"Precision@{len(matches0)}: {precision:.4f}")
print(f"Recall@{len(matches0)}: {recall:.4f}")
```
这里根据命令行参数设置是否显示和保存可视化结果,以及是否评估匹配和重建性能。如果需要显示和保存结果,则调用 `plot_keypoints`、`plot_matches`、`plot_reprojection` 和 `plot_trajectory` 函数生成相应的图像,保存到指定目录。如果需要评估性能,则调用 `compute_repeatability` 和 `compute_precision_recall` 函数计算重复性、精度和召回率,输出结果。最后还调用了 `save_trajectory` 函数将重建得到的点云坐标保存到文件中。
阅读全文