# Original request accompanying this code: "Please modify this code."
import torch
from djitellopy import Tello
import cv2
import numpy as np
import models
from models import yolo


def get_model(cfg='models/yolov5s.yaml', weights='weights/yolov5s.pt', device=None):
    """Build a YOLOv5 model from its YAML definition and load pretrained weights.

    Args:
        cfg: Path to the model definition file.
        weights: Path to the checkpoint file.
        device: Optional torch device (e.g. ``'cuda:0'``) to move the model
            to. ``None`` leaves it on the CPU.

    Returns:
        The model, switched to eval mode.
    """
    model = models.yolo.Model(cfg)
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine. weights_only=False is required on newer torch, whose default
    # refuses to unpickle whole model objects (which YOLOv5 checkpoints hold).
    # NOTE(review): unpickling can run arbitrary code -- only load trusted files.
    try:
        checkpoint = torch.load(weights, map_location='cpu', weights_only=False)
    except TypeError:  # older torch without the weights_only argument
        checkpoint = torch.load(weights, map_location='cpu')
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state = checkpoint['model']
    else:
        state = checkpoint
    # YOLOv5 checkpoints store the nn.Module itself under 'model' (typically
    # in half precision); load_state_dict() needs a mapping, so extract one.
    if isinstance(state, torch.nn.Module):
        state = state.float().state_dict()
    model.load_state_dict(state)
    if device is not None:
        model.to(device)
    model.eval()
    return model


def preprocess_frame(img, size=640):
    """Convert a camera frame into a batched YOLOv5 input tensor.

    Args:
        img: H x W x 3 uint8 image.
        size: Side length of the square network input (default 640).
            NOTE(review): presumably must be a multiple of the model's
            largest stride (32 for standard YOLOv5) -- confirm.

    Returns:
        float32 tensor of shape (1, 3, size, size); values are in [0, 1]
        for uint8 input.
    """
    # NOTE(review): assumes the frame is BGR (OpenCV order); some djitellopy
    # versions deliver RGB frames, in which case this swap is wrong -- confirm.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Plain resize (no letterbox): x and y must be rescaled separately later.
    rgb = cv2.resize(rgb, (size, size))
    # HWC -> CHW, normalised directly in float32 (no float64 intermediate).
    chw = np.ascontiguousarray(rgb.transpose(2, 0, 1), dtype=np.float32) / 255.0
    return torch.from_numpy(chw).unsqueeze(0)


def _decode_detections(pred, conf_thres, iou_thres, max_wh=7680):
    """Filter raw YOLOv5 rows into NMS-ed (x1, y1, x2, y2, conf, cls) rows.

    Args:
        pred: 2-D array with one row per candidate box.
            NOTE(review): assumes the raw YOLOv5 Detect layout
            [cx, cy, w, h, objectness, class_0, ..., class_k] in pixels of
            the network input -- confirm against the models/yolo.py in use.
        conf_thres: Minimum score (objectness * class probability) to keep.
        iou_thres: IoU above which overlapping same-class boxes are dropped.
        max_wh: Per-class coordinate offset so NMS never compares boxes of
            different classes; must exceed the input size.

    Returns:
        (M, 6) array of [x1, y1, x2, y2, conf, cls] in ``pred``'s pixel space.
    """
    empty = np.zeros((0, 6), dtype=np.float32)
    if pred.ndim != 2 or pred.shape[1] < 6:
        return empty
    scores = pred[:, 5:] * pred[:, 4:5]
    cls = scores.argmax(axis=1)
    conf = scores[np.arange(len(scores)), cls]
    keep = conf > conf_thres
    if not keep.any():
        return empty
    xywh, conf, cls = pred[keep, :4], conf[keep], cls[keep]
    top_left = xywh[:, :2] - xywh[:, 2:] / 2
    shifted = np.hstack([top_left + cls[:, None] * max_wh, xywh[:, 2:]])
    idx = cv2.dnn.NMSBoxes(shifted.tolist(), conf.tolist(), conf_thres, iou_thres)
    # NMSBoxes returns (N,), (N, 1) or an empty tuple depending on the version.
    idx = np.asarray(idx, dtype=np.int64).reshape(-1)
    x1y1 = top_left[idx]
    x2y2 = x1y1 + xywh[idx, 2:]
    return np.column_stack([x1y1, x2y2, conf[idx], cls[idx]])


def process_frame(model, img, conf_thres=0.25, iou_thres=0.45):
    """Detect objects in one frame, draw them and show the annotated frame.

    Args:
        model: YOLOv5 model in eval mode.
        img: H x W x 3 frame; not modified (boxes are drawn on a copy).
        conf_thres: Minimum detection score to draw.
        iou_thres: IoU threshold for per-class non-maximum suppression.

    Returns:
        The key code from ``cv2.waitKey(1)`` (-1 when no key was pressed).
    """
    inputs = preprocess_frame(img)
    device = next(model.parameters()).device
    with torch.no_grad():  # inference only: skip building an autograd graph
        output = model(inputs.to(device))
    # In eval mode YOLOv5 returns (predictions, feature_maps).
    pred = output[0] if isinstance(output, (tuple, list)) else output
    pred = pred[0].float().cpu().numpy()  # first (only) image in the batch
    detections = _decode_detections(pred, conf_thres, iou_thres)

    # Boxes are in network-input pixels; map them back to the frame.
    sx = img.shape[1] / inputs.shape[3]
    sy = img.shape[0] / inputs.shape[2]
    canvas = img.copy()  # don't draw on the stream's shared frame buffer
    for x1, y1, x2, y2, conf, cls in detections:
        p1 = (int(x1 * sx), int(y1 * sy))
        p2 = (int(x2 * sx), int(y2 * sy))
        cv2.rectangle(canvas, p1, p2, (255, 0, 0), 2)
        cv2.putText(canvas, f'{int(cls)} {conf:.2f}', (p1[0], p1[1] - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
    cv2.imshow('Tello with YOLOv5', canvas)
    return cv2.waitKey(1)


def main(frame_skip=2):
    """Stream Tello video through YOLOv5 until 'q' is pressed.

    Args:
        frame_skip: Run detection on one of every ``frame_skip`` frames
            (default 2). Values below 1 are treated as 1.
    """
    frame_skip = max(1, int(frame_skip))
    model = get_model()  # fail on bad weights before touching the drone
    tello = Tello()
    try:
        tello.connect()
        tello.streamon()
        frame_read = tello.get_frame_read()
        counter = 0
        while True:
            img = frame_read.frame
            if img is not None and counter % frame_skip == 0:
                # Reuse process_frame's waitKey result; a second waitKey call
                # would not see a key the first one already consumed.
                key = process_frame(model, img)
            else:
                key = cv2.waitKey(1)  # keep the GUI responsive on skipped frames
            counter += 1
            if key & 0xFF == ord('q'):  # press 'q' to quit
                break
    finally:
        cv2.destroyAllWindows()
        tello.end()


if __name__ == '__main__':
    main()
时间: 2023-11-13 22:02:23 浏览: 163
Pycharm中import torch报错的快速解决方法
import torch
from djitellopy import Tello
import cv2
import numpy as np
from models import yolo
def get_model(cfg='models/yolov5s.yaml', weights='weights/yolov5s.pt', device=None):
    """Build a YOLOv5 model from its YAML definition and load pretrained weights.

    Args:
        cfg: Path to the model definition file.
        weights: Path to the checkpoint file.
        device: Optional torch device (e.g. ``'cuda:0'``) to move the model
            to. ``None`` leaves it on the CPU.

    Returns:
        The model, switched to eval mode.
    """
    model = yolo.Model(cfg)
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine. weights_only=False is required on newer torch, whose default
    # refuses to unpickle whole model objects (which YOLOv5 checkpoints hold).
    # NOTE(review): unpickling can run arbitrary code -- only load trusted files.
    try:
        checkpoint = torch.load(weights, map_location='cpu', weights_only=False)
    except TypeError:  # older torch without the weights_only argument
        checkpoint = torch.load(weights, map_location='cpu')
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state = checkpoint['model']
    else:
        state = checkpoint
    # YOLOv5 checkpoints store the nn.Module itself under 'model' (typically
    # in half precision); load_state_dict() needs a mapping, so extract one.
    if isinstance(state, torch.nn.Module):
        state = state.float().state_dict()
    model.load_state_dict(state)
    if device is not None:
        model.to(device)
    model.eval()
    return model
def preprocess_frame(img, size=640):
    """Convert a camera frame into a batched YOLOv5 input tensor.

    Args:
        img: H x W x 3 uint8 image.
        size: Side length of the square network input (default 640).
            NOTE(review): presumably must be a multiple of the model's
            largest stride (32 for standard YOLOv5) -- confirm.

    Returns:
        float32 tensor of shape (1, 3, size, size); values are in [0, 1]
        for uint8 input.
    """
    # NOTE(review): assumes the frame is BGR (OpenCV order); some djitellopy
    # versions deliver RGB frames, in which case this swap is wrong -- confirm.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Plain resize (no letterbox): the aspect ratio is not preserved, so
    # callers must rescale x and y separately to map boxes back.
    rgb = cv2.resize(rgb, (size, size))
    # HWC -> CHW, normalised directly in float32 (no float64 intermediate).
    chw = np.ascontiguousarray(rgb.transpose(2, 0, 1), dtype=np.float32) / 255.0
    return torch.from_numpy(chw).unsqueeze(0)  # add the batch dimension
def _decode_detections(pred, conf_thres, iou_thres, max_wh=7680):
    """Filter raw YOLOv5 rows into NMS-ed (x1, y1, x2, y2, conf, cls) rows.

    Args:
        pred: 2-D array with one row per candidate box.
            NOTE(review): assumes the raw YOLOv5 Detect layout
            [cx, cy, w, h, objectness, class_0, ..., class_k] in pixels of
            the network input -- confirm against the models/yolo.py in use.
        conf_thres: Minimum score (objectness * class probability) to keep.
        iou_thres: IoU above which overlapping same-class boxes are dropped.
        max_wh: Per-class coordinate offset so NMS never compares boxes of
            different classes; must exceed the input size.

    Returns:
        (M, 6) array of [x1, y1, x2, y2, conf, cls] in ``pred``'s pixel space.
    """
    empty = np.zeros((0, 6), dtype=np.float32)
    if pred.ndim != 2 or pred.shape[1] < 6:
        return empty
    scores = pred[:, 5:] * pred[:, 4:5]
    cls = scores.argmax(axis=1)
    conf = scores[np.arange(len(scores)), cls]
    keep = conf > conf_thres
    if not keep.any():
        return empty
    xywh, conf, cls = pred[keep, :4], conf[keep], cls[keep]
    top_left = xywh[:, :2] - xywh[:, 2:] / 2
    shifted = np.hstack([top_left + cls[:, None] * max_wh, xywh[:, 2:]])
    idx = cv2.dnn.NMSBoxes(shifted.tolist(), conf.tolist(), conf_thres, iou_thres)
    # NMSBoxes returns (N,), (N, 1) or an empty tuple depending on the version.
    idx = np.asarray(idx, dtype=np.int64).reshape(-1)
    x1y1 = top_left[idx]
    x2y2 = x1y1 + xywh[idx, 2:]
    return np.column_stack([x1y1, x2y2, conf[idx], cls[idx]])


def process_frame(model, img, conf_thres=0.25, iou_thres=0.45):
    """Detect objects in one frame, draw them and show the annotated frame.

    Args:
        model: YOLOv5 model in eval mode.
        img: H x W x 3 frame; not modified (boxes are drawn on a copy).
        conf_thres: Minimum detection score to draw.
        iou_thres: IoU threshold for per-class non-maximum suppression.

    Returns:
        The key code from ``cv2.waitKey(1)`` (-1 when no key was pressed),
        so the caller can react to keyboard input.
    """
    inputs = preprocess_frame(img)
    # Run on whatever device the model lives on.
    device = next(model.parameters()).device
    with torch.no_grad():  # inference only: skip building an autograd graph
        output = model(inputs.to(device))
    # In eval mode YOLOv5 returns (predictions, feature_maps).
    pred = output[0] if isinstance(output, (tuple, list)) else output
    pred = pred[0].float().cpu().numpy()  # first (only) image in the batch
    detections = _decode_detections(pred, conf_thres, iou_thres)

    # Boxes are in network-input pixels (not [0, 1]); map them back to the
    # frame. The frame is HxWx3, so width is shape[1] and height is shape[0].
    sx = img.shape[1] / inputs.shape[3]
    sy = img.shape[0] / inputs.shape[2]
    canvas = img.copy()  # don't draw on the stream's shared frame buffer
    for x1, y1, x2, y2, conf, cls in detections:
        p1 = (int(x1 * sx), int(y1 * sy))
        p2 = (int(x2 * sx), int(y2 * sy))
        cv2.rectangle(canvas, p1, p2, (255, 0, 0), 2)
        # Class id and confidence just above the box.
        cv2.putText(canvas, f'{int(cls)} {conf:.2f}', (p1[0], p1[1] - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
    cv2.imshow('Tello with YOLOv5', canvas)
    return cv2.waitKey(1)
def main(frame_skip=1):
    """Stream Tello video through YOLOv5 until 'q' is pressed.

    Args:
        frame_skip: Run detection on one of every ``frame_skip`` frames
            (1 = every frame). Values below 1 are treated as 1.
    """
    frame_skip = max(1, int(frame_skip))
    # Load the model first so a bad weights path fails before the drone
    # connection and video stream are opened.
    model = get_model()
    tello = Tello()
    try:
        tello.connect()
        tello.streamon()
        frame_read = tello.get_frame_read()
        counter = 0
        while True:
            img = frame_read.frame
            if img is not None and counter % frame_skip == 0:
                # process_frame() already calls cv2.waitKey(1); reuse its key
                # code, because a second waitKey call would not see a key the
                # first one already consumed.
                key = process_frame(model, img)
            else:
                # Keep the GUI responsive and the keyboard polled on
                # skipped frames as well.
                key = cv2.waitKey(1)
            counter += 1
            if key & 0xFF == ord('q'):  # press 'q' to quit
                break
    finally:
        # Always release the window and the drone, even on errors.
        cv2.destroyAllWindows()
        tello.end()


if __name__ == '__main__':
    main()
阅读全文