在python3.6,pytorch1.10.2,cuda11.3,numpy1.19.5环境下,完成一个名为yolov7的类,实现本地加载用自己数据集训练的yolov5的.pth模型,对图片进行检测并以列表的形式输出类别以及检测框的四个顶点位置,写成函数的形式调用

时间: 2023-06-18 16:02:39 浏览: 79
首先,需要安装 `torch` 和 `opencv-python` 库。然后,可以使用以下代码实现你所需的功能: ```python import torch import cv2 import numpy as np class YOLOv7: def __init__(self, model_path, device='cuda'): self.model = torch.load(model_path, map_location=device)['model'].float() self.model.to(device).eval() self.device = device self.anchors = torch.tensor([[10,13], [16,30], [33,23], [30,61], [62,45], [59,119], [116,90], [156,198], [373,326]]).to(device) self.stride = torch.tensor([8, 16, 32]).to(device) self.grid_size = 0 self.img_size = 0 def detect(self, img): self.img_size = img.shape[1], img.shape[0] img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = cv2.resize(img, (640, 640), interpolation=cv2.INTER_LINEAR) img = img.astype(np.float32) / 255. img = torch.from_numpy(img).unsqueeze(0).permute(0, 3, 1, 2).to(self.device) self.grid_size = img.shape[2] // self.stride with torch.no_grad(): pred = self.model(img) outputs = self.postprocess(pred) return outputs def postprocess(self, pred): outputs = [] for i, p in enumerate(pred): if i == 0: anchor_idx = [3, 4, 5] else: anchor_idx = [0, 1, 2] grid_size = p.shape[2] stride = self.img_size[0] // grid_size scaled_anchors = self.anchors[anchor_idx] / stride prediction = self.decode(p, scaled_anchors) prediction[..., :4] *= stride outputs.append(prediction) outputs = torch.cat(outputs, 1) return self.non_max_suppression(outputs) def decode(self, pred, anchors): batch_size, _, grid_size, _ = pred.shape pred = pred.view(batch_size, 3, -1, grid_size, grid_size).permute(0, 1, 3, 4, 2).contiguous() x, y, w, h, obj, cls = torch.split(pred, [1, 1, 1, 1, 1, -1], dim=-1) x = torch.sigmoid(x) y = torch.sigmoid(y) obj = torch.sigmoid(obj) cls = torch.sigmoid(cls) grid_y, grid_x = torch.meshgrid(torch.arange(grid_size), torch.arange(grid_size)) xy_grid = torch.stack((grid_x, grid_y), dim=-1).to(self.device).float() xy_grid = xy_grid.view(1, 1, grid_size, grid_size, 2) xy_grid = xy_grid.repeat(batch_size, 3, 1, 1, 1) x += xy_grid[..., 0:1] y += xy_grid[..., 1:2] anchors = anchors.view(1, 3, 1, 1, 2).repeat(batch_size, 1, grid_size, grid_size, 1) w = torch.exp(w) * anchors[..., 0:1] h = torch.exp(h) * anchors[..., 1:2] x1 = x - w / 2 y1 = y - h / 2 x2 = x1 + w y2 = y1 + h prediction = torch.cat((x1, y1, x2, y2, obj, cls), dim=-1) return prediction.view(batch_size, -1, 6) def non_max_suppression(self, prediction): output = [] for i, image_pred in enumerate(prediction): # Filter out confidence scores below threshold conf_mask = (image_pred[:, 4] >= 0.5).squeeze() image_pred = image_pred[conf_mask] # If none are remaining => process next image if not image_pred.size(0): continue # Object confidence times class confidence score = image_pred[:, 4] * image_pred[:, 5:].max(1)[0] # Sort by it image_pred = image_pred[(-score).argsort()] class_confs, class_preds = image_pred[:, 5:].max(1, keepdim=True) detections = torch.cat((image_pred[:, :5], class_confs.float(), class_preds.float()), 1) # Iterate over detections for c in detections[:, -1].unique(): detections_class = detections[detections[:, -1] == c] # Sort by score keep = torch.tensor([], dtype=torch.long) while detections_class.size(0): large_overlap = self.bbox_iou(detections_class[:1, :4], detections_class[:, :4]) > 0.5 label_match = detections_class[0, -1] == detections_class[:, -1] # Indices of boxes with lower confidence scores, large IOUs and matching labels invalid = large_overlap & label_match keep = torch.cat((keep, detections_class[:1].long()), dim=0) detections_class = detections_class[~invalid] detections_class = detections[keep] # Append detections for this image output.extend(detections_class.tolist()) return output def bbox_iou(self, box1, box2): """ Returns the IoU of two bounding boxes """ box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1]) box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1]) inter_min = torch.max(box1[:, None, :2], box2[:, :2]) inter_max = torch.min(box1[:, None, 2:], box2[:, 2:]) inter_size = torch.clamp((inter_max - inter_min), min=0) inter_area = inter_size[:, :, 0] * inter_size[:, :, 1] iou = inter_area / (box1_area[:, None] + box2_area - inter_area) return iou ``` 然后,可以使用以下代码调用该类: ```python model_path = 'path/to/your/yolov5.pth' yolov7 = YOLOv7(model_path) img_path = 'path/to/your/image.jpg' img = cv2.imread(img_path) outputs = yolov7.detect(img) print(outputs) ``` 输出的 `outputs` 是一个列表,其中每个元素都是一个检测框的信息,包括类别、置信度和四个顶点位置。

相关推荐

最新推荐

recommend-type

PyTorch官方教程中文版.pdf

Py Torch是一个基于 Torch的 Python开源机器学习库,用于自然语言处理等应用程序。它主要由Facebook的人工智能小组开发,不仅能够实现强大的GPU加速,同时还支持动态神经网络,这点是现在很多主流框架如 TensorFlow...
recommend-type

pytorch 中pad函数toch.nn.functional.pad()的用法

今天小编就为大家分享一篇pytorch 中pad函数toch.nn.functional.pad()的用法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

Anaconda+Pycharm环境下的PyTorch配置方法

写给新手的话 pycharm是什么,为什么让我指定interpreter 记事本 ...这要是编写、运行、调试都能在同一个窗口里进行,再来点语法检查,高亮,颜色,代码提示,那写代码的效率不就高多了吗?所以就有了
recommend-type

pytorch实现mnist分类的示例讲解

今天小编就为大家分享一篇pytorch实现mnist分类的示例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

pytorch 实现数据增强分类 albumentations的使用

albumentations包是一种针对数据增强专门写的API,里面基本包含大量的数据增强手段,比起pytorch自带的ttransform更丰富,搭配使用效果更好。 代码和效果 import albumentations import cv2 from PIL import Image, ...
recommend-type

zigbee-cluster-library-specification

最新的zigbee-cluster-library-specification说明文档。
recommend-type

管理建模和仿真的文件

管理Boualem Benatallah引用此版本:布阿利姆·贝纳塔拉。管理建模和仿真。约瑟夫-傅立叶大学-格勒诺布尔第一大学,1996年。法语。NNT:电话:00345357HAL ID:电话:00345357https://theses.hal.science/tel-003453572008年12月9日提交HAL是一个多学科的开放存取档案馆,用于存放和传播科学研究论文,无论它们是否被公开。论文可以来自法国或国外的教学和研究机构,也可以来自公共或私人研究中心。L’archive ouverte pluridisciplinaire
recommend-type

实现实时数据湖架构:Kafka与Hive集成

![实现实时数据湖架构:Kafka与Hive集成](https://img-blog.csdnimg.cn/img_convert/10eb2e6972b3b6086286fc64c0b3ee41.jpeg) # 1. 实时数据湖架构概述** 实时数据湖是一种现代数据管理架构,它允许企业以低延迟的方式收集、存储和处理大量数据。与传统数据仓库不同,实时数据湖不依赖于预先定义的模式,而是采用灵活的架构,可以处理各种数据类型和格式。这种架构为企业提供了以下优势: - **实时洞察:**实时数据湖允许企业访问最新的数据,从而做出更明智的决策。 - **数据民主化:**实时数据湖使各种利益相关者都可
recommend-type

SPDK_NVMF_DISCOVERY_NQN是什么 有什么作用

SPDK_NVMF_DISCOVERY_NQN 是 SPDK (Storage Performance Development Kit) 中用于查询 NVMf (Non-Volatile Memory express over Fabrics) 存储设备名称的协议。NVMf 是一种基于网络的存储协议,可用于连接远程非易失性内存存储器。 SPDK_NVMF_DISCOVERY_NQN 的作用是让存储应用程序能够通过 SPDK 查询 NVMf 存储设备的名称,以便能够访问这些存储设备。通过查询 NVMf 存储设备名称,存储应用程序可以获取必要的信息,例如存储设备的IP地址、端口号、名称等,以便能
recommend-type

JSBSim Reference Manual

JSBSim参考手册,其中包含JSBSim简介,JSBSim配置文件xml的编写语法,编程手册以及一些应用实例等。其中有部分内容还没有写完,估计有生之年很难看到完整版了,但是内容还是很有参考价值的。