def test(): # 加载测试数据 test_loader = torch.utils.data.DataLoader( datasets.CIFAR10(root = args.data, train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])), batch_size=8, shuffle=False, num_workers=0) model.eval() correct = 0
时间: 2023-07-15 18:14:01 浏览: 117
这段代码定义了一个名为 `test` 的函数,该函数用于测试训练好的模型在 CIFAR10 数据集上的准确率。
首先,通过 `torch.utils.data.DataLoader` 加载 CIFAR10 数据集,并进行了数据预处理,包括将图片转化为张量,并进行归一化处理。
然后,将模型设置为评估模式,即 `model.eval()`。
最后,定义了一个变量 `correct`,用于存储分类正确的样本数。
相关问题
import torch, os, cv2 from model.model import parsingNet from utils.common import merge_config from utils.dist_utils import dist_print import torch import scipy.special, tqdm import numpy as np import torchvision.transforms as transforms from data.dataset import LaneTestDataset from data.constant import culane_row_anchor, tusimple_row_anchor if __name__ == "__main__": torch.backends.cudnn.benchmark = True args, cfg = merge_config() dist_print('start testing...') assert cfg.backbone in ['18','34','50','101','152','50next','101next','50wide','101wide'] if cfg.dataset == 'CULane': cls_num_per_lane = 18 elif cfg.dataset == 'Tusimple': cls_num_per_lane = 56 else: raise NotImplementedError net = parsingNet(pretrained = False, backbone=cfg.backbone,cls_dim = (cfg.griding_num+1,cls_num_per_lane,4), use_aux=False).cuda() # we dont need auxiliary segmentation in testing state_dict = torch.load(cfg.test_model, map_location='cpu')['model'] compatible_state_dict = {} for k, v in state_dict.items(): if 'module.' in k: compatible_state_dict[k[7:]] = v else: compatible_state_dict[k] = v net.load_state_dict(compatible_state_dict, strict=False) net.eval() img_transforms = transforms.Compose([ transforms.Resize((288, 800)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) if cfg.dataset == 'CULane': splits = ['test0_normal.txt', 'test1_crowd.txt', 'test2_hlight.txt', 'test3_shadow.txt', 'test4_noline.txt', 'test5_arrow.txt', 'test6_curve.txt', 'test7_cross.txt', 'test8_night.txt'] datasets = [LaneTestDataset(cfg.data_root,os.path.join(cfg.data_root, 'list/test_split/'+split),img_transform = img_transforms) for split in splits] img_w, img_h = 1640, 590 row_anchor = culane_row_anchor elif cfg.dataset == 'Tusimple': splits = ['test.txt'] datasets = [LaneTestDataset(cfg.data_root,os.path.join(cfg.data_root, split),img_transform = img_transforms) for split in splits] img_w, img_h = 1280, 720 row_anchor = tusimple_row_anchor else: raise NotImplementedError for split, dataset in zip(splits, datasets): loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle = False, num_workers=1) fourcc = cv2.VideoWriter_fourcc(*'MJPG') print(split[:-3]+'avi') vout = cv2.VideoWriter(split[:-3]+'avi', fourcc , 30.0, (img_w, img_h)) for i, data in enumerate(tqdm.tqdm(loader)): imgs, names = data imgs = imgs.cuda() with torch.no_grad(): out = net(imgs) col_sample = np.linspace(0, 800 - 1, cfg.griding_num) col_sample_w = col_sample[1] - col_sample[0] out_j = out[0].data.cpu().numpy() out_j = out_j[:, ::-1, :] prob = scipy.special.softmax(out_j[:-1, :, :], axis=0) idx = np.arange(cfg.griding_num) + 1 idx = idx.reshape(-1, 1, 1) loc = np.sum(prob * idx, axis=0) out_j = np.argmax(out_j, axis=0) loc[out_j == cfg.griding_num] = 0 out_j = loc # import pdb; pdb.set_trace() vis = cv2.imread(os.path.join(cfg.data_root,names[0])) for i in range(out_j.shape[1]): if np.sum(out_j[:, i] != 0) > 2: for k in range(out_j.shape[0]): if out_j[k, i] > 0: ppp = (int(out_j[k, i] * col_sample_w * img_w / 800) - 1, int(img_h * (row_anchor[cls_num_per_lane-1-k]/288)) - 1 ) cv2.circle(vis,ppp,5,(0,255,0),-1) vout.write(vis) vout.release()
这是一个使用PyTorch框架测试模型的代码,代码中导入了许多必要的库以及自定义的模型和数据集等。通过调用PyTorch的后端,启用一些加速技术,然后解析配置参数,并根据数据集类型设置类别数。创建一个模型实例,传入必要的参数并将其转移到GPU上以加速处理。最后进行模型测试。
import argparse import json import os import sys from pathlib import Path from threading import Thread import numpy as np import torch from tqdm import tqdm FILE = Path(__file__).resolve() ROOT = FILE.parents[0] # YOLOv5 root directory if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) # add ROOT to PATH ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative from models.common import DetectMultiBackend from utils.callbacks import Callbacks from utils.datasets import create_dataloader from utils.general import (LOGGER, NCOLS, box_iou, check_dataset, check_img_size, check_requirements, check_yaml, coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args, scale_coords, xywh2xyxy, xyxy2xywh) from utils.metrics import ConfusionMatrix, ap_per_class from utils.plots import output_to_target, plot_images, plot_val_study from utils.torch_utils import select_device, time_sync
这段代码主要是导入了一些Python库和自定义的模块,以及定义了一些全局变量,主要功能包括:
- argparse: 用于解析命令行参数。
- json: 用于解析JSON格式的数据。
- os: 提供了一些与操作系统交互的函数。
- sys: 提供了一些与Python解释器交互的函数。
- pathlib: 提供了一种面向对象的方式来操作文件路径。
- threading: 用于创建多线程程序。
- numpy: 用于支持大型、多维数组和矩阵运算。
- torch: 用于实现深度学习模型。
- tqdm: 用于显示进度条。
除此之外,还导入了自定义的模块,包括:
- DetectMultiBackend: 实现了多种目标检测算法的后端。
- Callbacks: 实现了一些回调函数,用于在训练过程中进行一些操作。
- create_dataloader: 用于创建数据加载器。
- check_dataset: 用于检查数据集。
- check_img_size: 用于检查图像大小。
- check_requirements: 用于检查是否满足运行要求。
- check_yaml: 用于检查YAML文件。
- coco80_to_coco91_class: 用于将COCO80类别映射到COCO91类别。
- colorstr: 用于生成颜色字符串。
- increment_path: 用于生成不重复的路径。
- non_max_suppression: 用于实现非极大值抑制。
- print_args: 用于打印命令行参数。
- scale_coords: 用于将坐标缩放。
- xywh2xyxy: 用于将中心坐标和宽高转换为左上角和右下角坐标。
- xyxy2xywh: 用于将左上角和右下角坐标转换为中心坐标和宽高。
其中,FILE定义了当前文件的路径,ROOT定义了YOLOv5的根目录,如果ROOT不在sys.path中,则将其添加到PATH中,并将其转换为相对路径。
阅读全文