model.eval详解
时间: 2023-11-09 07:25:04 浏览: 175
model.eval() 是 PyTorch 中用于将模型设置为评估模式的函数。在评估模式下,模型的行为会略有差异,主要体现在两个方面:
1. Batch Normalization 和 Dropout 操作会失效,并且会使用已保存的移动平均值进行归一化。
2. 在推理过程中,不需要计算并保存梯度,因为我们不需要对模型进行参数更新。
这些差异是因为在评估模式下,我们不需要进行反向传播计算梯度,只需要基于模型的输入和参数直接输出相应的评估结果即可。
因此,在对模型进行评估时,我们一般会用 model.eval() 临时将模型设置为评估模式,以确保所得到的结果是可靠的。
相关问题
model.train(0和model.eval()的区别
model.train()和model.eval()是pytorch中用于控制模型训练状态的方法。model.train()将模型设置为训练模式,而model.eval()将模型设置为评估模式。
在训练过程中,model.train()会启用Batch Normalization层(BN层)和Dropout层的计算,以便在每个batch的训练过程中进行正则化和随机失活。同时,它还会更新模型的参数,使其适应训练数据。
相反,model.eval()会将模型设置为评估模式,此时模型不会进行BN层和Dropout层的计算,因为在评估阶段不需要进行正则化和随机失活。此外,模型的参数也不会更新,因为评估阶段只是用来测试模型在新数据上的性能。
需要注意的是,使用model.eval()之后,需要手动使用torch.no_grad()上下文管理器来禁止梯度的计算。torch.no_grad()会包裹住的代码块不会被追踪梯度,也就是说不会记录计算过程,不能进行反向传播更新参数。
综上所述,model.train()用于模型训练阶段,开启BN层和Dropout层的计算并更新参数,而model.eval()用于模型评估阶段,关闭BN层和Dropout层的计算并不更新参数。<span class="em">1</span><span class="em">2</span><span class="em">3</span>
#### 引用[.reference_title]
- *1* *3* [pytorch中model.train和model.eval](https://blog.csdn.net/dagouxiaohui/article/details/125620786)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"]
- *2* [pytorch:model.train和model.eval用法及区别详解](https://download.csdn.net/download/weixin_38611254/12855267)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"]
[ .reference_list ]
loftr代码详解demo.py详解
demo.py 是 LoFTR 算法的主要测试文件,用于演示 LoFTR 算法在图像匹配和三维重建任务上的效果,下面我将对其中的关键代码进行详细解释。
首先是导入必要的包和模块:
```python
import os
import sys
import time
import argparse
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from utils.tools import (
Timer,
str2bool,
load_image,
plot_keypoints,
plot_matches,
plot_reprojection,
plot_trajectory,
save_trajectory,
)
from utils.evaluation import compute_repeatability, compute_precision_recall
from models.loftr import LoFTR, default_cfg
```
其中 `os`、`sys`、`time`、`argparse`、`numpy`、`cv2`、`torch`、`matplotlib.pyplot` 都是 Python 常用的标准库或第三方库,此处不再赘述。`plot_keypoints`、`plot_matches`、`plot_reprojection` 和 `plot_trajectory` 是自定义的用于可视化结果的函数。`compute_repeatability` 和 `compute_precision_recall` 是用于评估匹配和重建结果的函数,详见 `evaluation.py` 文件。`LoFTR` 是 LoFTR 模型的主要类,`default_cfg` 是 LoFTR 模型的默认配置。
然后是解析命令行参数:
```python
parser = argparse.ArgumentParser(
description="LoFTR: Detector-Suppressor for 3D Local Features",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("img0_path", type=str, help="path to the reference image")
parser.add_argument("img1_path", type=str, help="path to the query image")
parser.add_argument("--weights", type=str, default=default_cfg["weights"], help="path to the pretrained model weights")
parser.add_argument("--cfg", type=str, default=default_cfg["cfg"], help="path to the config file")
parser.add_argument("--matcher", type=str, default=default_cfg["matcher"], help="feature matcher")
parser.add_argument("--suppression", type=str, default=default_cfg["suppression"], help="feature suppressor")
parser.add_argument("--top_k", type=int, default=default_cfg["top_k"], help="keep top_k keypoints")
parser.add_argument("--max_length", type=int, default=default_cfg["max_length"], help="maximum sequence length")
parser.add_argument("--resize", type=str2bool, nargs="?", const=True, default=True, help="resize input images")
parser.add_argument("--show", type=str2bool, nargs="?", const=True, default=False, help="show results")
parser.add_argument(
"--eval", type=str2bool, nargs="?", const=True, default=False, help="evaluate repeatability and matching performance"
)
parser.add_argument("--output_dir", type=str, default="outputs", help="output directory")
args = parser.parse_args()
```
其中 `img0_path` 和 `img1_path` 分别表示参考图像和查询图像的路径。`weights`、`cfg`、`matcher`、`suppression`、`top_k`、`max_length` 分别表示 LoFTR 模型的权重文件、配置文件、特征匹配器、特征抑制器、保留的关键点数量、序列的最大长度。`resize`、`show`、`eval`、`output_dir` 分别表示是否对输入图像进行缩放、是否显示结果、是否评估性能、输出结果的目录。
接下来是读取图像并将其转换为张量:
```python
img0 = load_image(args.img0_path, resize=args.resize)
img1 = load_image(args.img1_path, resize=args.resize)
```
其中 `load_image` 函数用于加载图像,将其转换为 BGR 格式,并将其缩放到固定大小,返回值是一个 `torch.Tensor` 对象。
然后是加载 LoFTR 模型:
```python
loftr = LoFTR(args.cfg, args.weights, args.matcher, args.suppression, args.top_k, args.max_length)
```
这里调用了 `LoFTR` 类,传入参数表示加载指定的配置文件、权重文件、特征匹配器、特征抑制器、保留的关键点数量和序列的最大长度。该类主要包含以下方法:
- `__init__(self, cfg, weights, matcher, suppression, top_k, max_length)`:初始化函数,加载模型权重和配置文件。
- `extract_features(self, img)`:提取图像的局部特征,返回值是一个元组 `(keypoints, descriptors)`,其中 `keypoints` 是关键点坐标,`descriptors` 是关键点特征描述子。
- `match(self, ref_feats, query_feats)`:在参考图像和查询图像的局部特征之间进行匹配,返回值是一个元组 `(matches_ref, matches_query)`,其中 `matches_ref` 是参考图像中的匹配点坐标,`matches_query` 是查询图像中的匹配点坐标。
- `reconstruct(self, ref_img, ref_feats, query_img, query_feats, matches)`:利用 LoFTR 算法进行三维重建,返回值是一个元组 `(R, t, pts_ref, pts_query, pts_3d)`,其中 `R` 和 `t` 是参考图像到查询图像的旋转矩阵和平移向量,`pts_ref` 和 `pts_query` 是参考图像和查询图像中的匹配点坐标,`pts_3d` 是三维重建得到的点云坐标。
接下来是提取图像的局部特征:
```python
timer = Timer()
timer.start()
kpts0, desc0 = loftr.extract_features(img0)
kpts1, desc1 = loftr.extract_features(img1)
timer.stop('extract features')
```
这里调用了 `extract_features` 方法,传入参数是加载的 LoFTR 模型和图像张量,返回值是两个元组 `(keypoints, descriptors)`,分别表示两幅图像的关键点坐标和特征描述子。这里还使用了 `Timer` 类来统计函数运行时间,方便后面的性能评估。
然后是在两幅图像之间进行特征匹配:
```python
timer.start()
matches0, matches1 = loftr.match((kpts0, desc0), (kpts1, desc1))
timer.stop('match features')
```
这里调用了 `match` 方法,传入参数是两个元组 `(keypoints, descriptors)`,分别表示参考图像和查询图像的关键点坐标和特征描述子。返回值也是两个元组 `(matches_ref, matches_query)`,分别表示参考图像和查询图像中的匹配点坐标。这里同样使用了 `Timer` 类来统计函数运行时间。
接下来是在两幅图像之间进行三维重建:
```python
timer.start()
R, t, pts0, pts1, pts3d = loftr.reconstruct(img0, (kpts0, desc0), img1, (kpts1, desc1), (matches0, matches1))
timer.stop('reconstruct 3D')
```
这里调用了 `reconstruct` 方法,传入参数是参考图像、参考图像的局部特征、查询图像、查询图像的局部特征和两幅图像之间的匹配点坐标。返回值是一个元组 `(R, t, pts0, pts1, pts3d)`,分别表示参考图像到查询图像的旋转矩阵和平移向量,两幅图像中的匹配点坐标和三维重建得到的点云坐标。同样使用了 `Timer` 类来统计函数运行时间。
最后是对结果进行可视化和保存:
```python
if args.show or args.eval:
plot_keypoints(img0, kpts0, title="Image 0 keypoints")
plot_keypoints(img1, kpts1, title="Image 1 keypoints")
plot_matches(img0, img1, matches0, matches1, title="Matches", savepath=os.path.join(args.output_dir, "matches.png"))
plot_reprojection(pts0, pts1, pts3d, R, t, title="Reprojection", savepath=os.path.join(args.output_dir, "reprojection.png"))
plot_trajectory(pts3d, title="Trajectory", savepath=os.path.join(args.output_dir, "trajectory.png"))
save_trajectory(os.path.join(args.output_dir, "trajectory.txt"), pts3d)
if args.eval:
repeatability = compute_repeatability(kpts0, kpts1, matches0, matches1)
precision, recall = compute_precision_recall(kpts0, kpts1, matches0, matches1, pts3d)
print(f"Repeatability: {repeatability:.4f}")
print(f"Precision@{len(matches0)}: {precision:.4f}")
print(f"Recall@{len(matches0)}: {recall:.4f}")
```
这里根据命令行参数设置是否显示和保存可视化结果,以及是否评估匹配和重建性能。如果需要显示和保存结果,则调用 `plot_keypoints`、`plot_matches`、`plot_reprojection` 和 `plot_trajectory` 函数生成相应的图像,保存到指定目录。如果需要评估性能,则调用 `compute_repeatability` 和 `compute_precision_recall` 函数计算重复性、精度和召回率,输出结果。最后还调用了 `save_trajectory` 函数将重建得到的点云坐标保存到文件中。
阅读全文