yolov5数据集扩充python完整代码
时间: 2023-08-23 15:05:47 浏览: 52
以下是一段使用 Albumentations 库对 YOLOv5 数据集进行数据增强(扩充)的 Python 代码。运行前需要先安装 Albumentations 和 OpenCV,例如:`pip install albumentations opencv-python`。
```python
import os
import cv2
import numpy as np
from albumentations import (
    HorizontalFlip, VerticalFlip, Rotate, Blur, RandomBrightnessContrast,
    RandomGamma, RandomResizedCrop, Compose, BboxParams
)

# Augmentation pipeline.
# bbox_params is required so that Albumentations transforms the bounding boxes
# together with the image and keeps the class ids aligned with the boxes that
# survive the transform. Boxes are handed over as absolute pixel corners
# (pascal_voc: x_min, y_min, x_max, y_max).
transform = Compose(
    [
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        Rotate(limit=10, p=0.5),
        Blur(blur_limit=3, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        RandomGamma(gamma_limit=(80, 120), p=0.5),
        # NOTE(review): recent Albumentations releases expect size=(416, 416)
        # instead of height/width — confirm against the installed version.
        RandomResizedCrop(height=416, width=416, scale=(0.8, 1.0), ratio=(0.75, 1.3333333333333333), p=0.5),
    ],
    bbox_params=BboxParams(format='pascal_voc', label_fields=['class_ids']),
)

# Dataset root: walked recursively for .jpg images with sibling .txt labels.
data_path = '/path/to/data'
# Directory that receives the augmented images and label files.
output_path = '/path/to/output'


def _load_yolo_boxes(label_path, img_w, img_h):
    """Read a YOLO label file and return pixel-space boxes and class ids.

    Args:
        label_path: Path of the ``.txt`` label file; it may not exist.
        img_w: Width in pixels of the image the labels belong to.
        img_h: Height in pixels of the image the labels belong to.

    Returns:
        A ``(bboxes, classes)`` tuple. ``bboxes`` is a list of
        ``[x_min, y_min, x_max, y_max]`` floats clipped to the image, and
        ``classes`` is the list of matching integer class ids. Both are empty
        when the file is missing or contains no usable rows.
    """
    bboxes = []
    classes = []
    if not os.path.exists(label_path):
        return bboxes, classes
    with open(label_path, 'r') as f:
        for line in f:
            # split() without arguments tolerates tabs, repeated spaces and
            # blank lines (which yield an empty list).
            parts = line.split()
            if len(parts) < 5:
                continue
            x, y, w, h = map(float, parts[1:5])
            # Keep float precision and clip to the image so the boxes pass
            # Albumentations' range validation.
            x1 = min(max((x - w / 2) * img_w, 0.0), float(img_w))
            y1 = min(max((y - h / 2) * img_h, 0.0), float(img_h))
            x2 = min(max((x + w / 2) * img_w, 0.0), float(img_w))
            y2 = min(max((y + h / 2) * img_h, 0.0), float(img_h))
            # Degenerate boxes (zero width or height) are rejected by
            # Albumentations, so drop them here.
            if x2 <= x1 or y2 <= y1:
                continue
            bboxes.append([x1, y1, x2, y2])
            classes.append(int(parts[0]))
    return bboxes, classes


# cv2.imwrite fails silently when the target directory does not exist.
os.makedirs(output_path, exist_ok=True)

# Walk the dataset, augment every image and save the image plus its labels.
for root, dirs, files in os.walk(data_path):
    for file in files:
        if not file.endswith('.jpg'):
            continue
        img_path = os.path.join(root, file)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for unreadable or corrupt files.
            print(f"Skipping unreadable image: {img_path}")
            continue
        img_h, img_w = img.shape[:2]

        # The label file sits next to the image and shares its base name.
        # An image without labels is augmented with an empty box list.
        label_path = os.path.splitext(img_path)[0] + '.txt'
        bboxes, classes = _load_yolo_boxes(label_path, img_w, img_h)

        # Apply the augmentation to the image and its boxes together.
        transformed = transform(image=img, bboxes=bboxes, class_ids=classes)
        out_img = transformed['image']
        out_h, out_w = out_img.shape[:2]

        # Save the augmented image.
        new_img_path = os.path.join(output_path, file)
        cv2.imwrite(new_img_path, out_img)

        # Save the augmented labels, normalized by the size of the augmented
        # image (which can differ from the input after RandomResizedCrop).
        new_label_path = os.path.splitext(new_img_path)[0] + '.txt'
        with open(new_label_path, 'w') as f:
            for box, class_id in zip(transformed['bboxes'], transformed['class_ids']):
                x1, y1, x2, y2 = box[:4]
                x, y = (x1 + x2) / 2 / out_w, (y1 + y2) / 2 / out_h
                w, h = (x2 - x1) / out_w, (y2 - y1) / out_h
                f.write(f"{int(class_id)} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n")
```
这段代码会遍历指定的数据集目录,对每张图片进行数据增强,并保存增强后的图片和对应的标注信息。注意:YOLO 标注使用归一化的 (中心点 x, 中心点 y, 宽, 高) 格式,因此在保存标注信息时,需要将坐标和宽高转换为相对于增强后图片尺寸的比例。