def save_conf_to_yaml(self): """dict保存为yaml""" with open(self.yaml, 'w') as file: file.write(yaml.dump(self.read_froth_conf(), allow_unicode=True, default_flow_style=False, sort_keys=False))
时间: 2024-04-02 07:35:28 浏览: 137
这是一个方法,它将一个字典对象保存为yaml格式的文件。方法名为 save_conf_to_yaml(),它接受一个self参数,表示该方法是一个类的成员方法。该方法打开一个文件,将字典对象转换为yaml格式并写入该文件中。其中,self.yaml表示要保存的yaml文件路径,self.read_froth_conf()返回一个字典对象,yaml.dump()方法将字典对象转换为yaml格式的字符串,最后通过file.write()方法将字符串写入文件中。其中的一些参数包括allow_unicode表示是否允许写入unicode字符,default_flow_style表示是否使用默认的yaml流样式,sort_keys表示是否按字典键排序。
相关问题
优化以下代码 def yml() -> dict: """读取配置文件""" try: with open(os.path.abspath('./conf/config.yml'), 'r', encoding='utf-8') as f_yml: # r = yaml.load(f_yml, Loader=yaml.FullLoader) r = yaml.safe_load(f_yml) return r except yaml.YAMLError as Y: logging.error("config.yml文件内容配置有误!\n%s" % Y)
优化代码如下:
```python
def yml() -> dict:
"""读取配置文件"""
try:
with open(os.path.abspath('./conf/config.yml'), 'r', encoding='utf-8') as f_yml:
return yaml.safe_load(f_yml)
except yaml.YAMLError as exc:
logging.error("config.yml文件内容配置有误!\n%s", exc)
```
主要优化点:
1. 删除了注释,让代码更加简洁易读;
2. 使用 `yaml.safe_load()` 代替 `yaml.load()`,更加安全;
3. 修改了日志输出方式,使用了更加简洁的写法,并且将 `%` 替换为了 `,`,避免了字符串中包含 `%` 导致的错误。
import torch from djitellopy import Tello import cv2 import numpy as np import models from models import yolo def get_model(): # 假设 'yolov5s.yaml' 是 yolov5s 模型的定义文件的路径 model = models.yolo.Model('models/yolov5s.yaml') checkpoint = torch.load('weights/yolov5s.pt') model.load_state_dict(checkpoint['model']) model.eval() return model def preprocess_frame(img): img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = cv2.resize(img, (640, 640)) # 将图像大小调整为模型的输入大小 img = img / 255.0 # 将像素值归一化到 [0, 1] img = np.transpose(img, (2, 0, 1)) # 将图像从 HWC 格式转换为 CHW 格式 img = torch.from_numpy(img).float() # 将 Numpy 数组转换为 PyTorch 张量 img = img.unsqueeze(0) # 增加一个批量维度 return img def process_frame(model, img): img_preprocessed = preprocess_frame(img) results = model(img_preprocessed) # 处理模型的输出 results = results[0].detach().cpu().numpy() # 将结果从 GPU 移动到 CPU 并转换为 Numpy 数组 for x1, y1, x2, y2, conf, cls in results: # 将坐标从 [0, 1] 范围转换回图像的像素坐标 x1, y1, x2, y2 = x1 * img.shape[1], y1 * img.shape[0], x2 * img.shape[1], y2 * img.shape[0] # 在图像上画出边界框 cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2) # 在边界框旁边显示类别和置信度 cv2.putText(img, f'{int(cls)} {conf:.2f}', (int(x1), int(y1) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2) # 显示图像 cv2.imshow('Tello with YOLOv5', img) return cv2.waitKey(1) def main(): tello = Tello() tello.connect() tello.streamon() frame_read = tello.get_frame_read() model = get_model() frame_skip = 2 # 每两帧处理一次 counter = 0 while True: if counter % frame_skip == 0: # 只处理每两帧中的一帧 img = frame_read.frame process_frame(model, img) counter += 1 cv2.destroyAllWindows() if __name__ == '__main__': main() 修改这段代码
import torch
from djitellopy import Tello
import cv2
import numpy as np
from models import yolo
def get_model():
# 假设 'yolov5s.yaml' 是 yolov5s 模型的定义文件的路径
model = yolo.Model('models/yolov5s.yaml')
checkpoint = torch.load('weights/yolov5s.pt')
model.load_state_dict(checkpoint['model'])
model.eval()
return model
def preprocess_frame(img):
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (640, 640)) # 将图像大小调整为模型的输入大小
img = img / 255.0 # 将像素值归一化到 [0, 1]
img = np.transpose(img, (2, 0, 1)) # 将图像从 HWC 格式转换为 CHW 格式
img = torch.from_numpy(img).float() # 将 Numpy 数组转换为 PyTorch 张量
img = img.unsqueeze(0) # 增加一个批量维度
return img
def process_frame(model, img):
img_preprocessed = preprocess_frame(img)
results = model(img_preprocessed)
# 处理模型的输出
results = results[0].detach().cpu().numpy() # 将结果从 GPU 移动到 CPU 并转换为 Numpy 数组
for x1, y1, x2, y2, conf, cls in results:
# 将坐标从 [0, 1] 范围转换回图像的像素坐标
x1, y1, x2, y2 = int(x1 * img.shape[3]), int(y1 * img.shape[2]), int(x2 * img.shape[3]), int(y2 * img.shape[2])
# 在图像上画出边界框
cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
# 在边界框旁边显示类别和置信度
cv2.putText(img, f'{int(cls)} {conf:.2f}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
# 显示图像
cv2.imshow('Tello with YOLOv5', img)
return cv2.waitKey(1)
def main():
tello = Tello()
tello.connect()
tello.streamon()
frame_read = tello.get_frame_read()
model = get_model()
frame_skip = 1 # 每一帧处理一次
counter = 0
while True:
img = frame_read.frame
if counter % frame_skip == 0: # 只处理每一帧
process_frame(model, img)
counter += 1
if cv2.waitKey(1) & 0xFF == ord('q'): # 按下 'q' 键退出
break
cv2.destroyAllWindows()
if __name__ == '__main__':
main()
阅读全文