YOLOv5 to ONNX (GPU) Code
Date: 2023-08-24 12:08:17
The following example code converts a YOLOv5 model to ONNX and runs inference on the GPU:
```python
import torch
import onnx
import onnxruntime
from models.experimental import attempt_load
from utils.general import non_max_suppression

# Load YOLOv5 model
weights = 'path/to/weights.pt'
device = torch.device('cuda:0')  # GPU
model = attempt_load(weights, map_location=device)  # load FP32 model
model.to(device).eval()

# Dummy input
batch_size = 1
input_size = (640, 640)
dummy_input = torch.randn(batch_size, 3, input_size[0], input_size[1], device=device)

# Convert to ONNX
input_names = ['input']
output_names = ['output']
onnx_path = 'path/to/model.onnx'
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
torch.onnx.export(model, dummy_input, onnx_path, opset_version=11,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes=dynamic_axes)

# Sanity-check the exported graph
onnx.checker.check_model(onnx.load(onnx_path))

# Load the ONNX model on the GPU. Without an explicit provider list,
# ONNX Runtime falls back to the CPU execution provider, so the
# CUDAExecutionProvider must be requested here.
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess = onnxruntime.InferenceSession(
    onnx_path, sess_options=sess_options,
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

# Run inference
input_data = dummy_input.cpu().numpy()  # ONNX Runtime expects a numpy array
outputs = sess.run(None, {input_names[0]: input_data})

# Post-process detections
detections = non_max_suppression(torch.from_numpy(outputs[0]),
                                 conf_thres=0.25, iou_thres=0.45)
```
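The dummy tensor above is only for tracing the export; real inference needs an image preprocessed to the model's input layout. A minimal sketch of that step, using a hypothetical `preprocess` helper with a simplified nearest-neighbour resize (YOLOv5 itself uses letterbox resizing, which preserves aspect ratio):

```python
import numpy as np

def preprocess(image_hwc_uint8, size=(640, 640)):
    """Prepare an HxWx3 uint8 image for the exported ONNX model.

    Hypothetical helper, not part of YOLOv5: nearest-neighbour resize,
    scale pixels to [0, 1], HWC -> CHW, and add a batch dimension.
    """
    h, w = image_hwc_uint8.shape[:2]
    # Nearest-neighbour resize via index sampling (stand-in for letterbox)
    ys = (np.arange(size[0]) * h // size[0]).clip(0, h - 1)
    xs = (np.arange(size[1]) * w // size[1]).clip(0, w - 1)
    resized = image_hwc_uint8[ys[:, None], xs[None, :]]
    x = resized.astype(np.float32) / 255.0  # normalize to [0, 1]
    x = x.transpose(2, 0, 1)[None]          # HWC -> NCHW, batch dim of 1
    return x
```

The resulting array can be fed directly to `sess.run(None, {'input': x})` in place of the dummy input.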
Note that this example assumes the YOLOv5 source code is installed and importable as a Python module, so that `from models.experimental import attempt_load` resolves. If you do not have the YOLOv5 source, download it and place it on your Python path.
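One way to make a locally cloned checkout importable without installing it is to put the repo directory on `sys.path` before the imports run. A small sketch, where `repo_dir` is a hypothetical path to wherever you cloned https://github.com/ultralytics/yolov5:

```python
import importlib.util
import sys

def ensure_yolov5_on_path(repo_dir="yolov5"):
    """Sketch: make a local YOLOv5 checkout importable.

    `repo_dir` is a hypothetical path to your clone of the
    ultralytics/yolov5 repository; adjust it to your setup.
    Returns True if the `models` package can now be found.
    """
    if importlib.util.find_spec("models") is None:
        sys.path.insert(0, repo_dir)  # prepend so the repo wins lookups
    return importlib.util.find_spec("models") is not None
```

Call `ensure_yolov5_on_path("path/to/yolov5")` once at the top of your script, before `from models.experimental import attempt_load`.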