yolov5利用旋转,裁剪扩充数据集代码
时间: 2023-08-23 08:05:03 浏览: 102
以下是使用Python和OpenCV实现YoloV5数据增强的代码,包括旋转和裁剪等操作:
```python
import cv2
import numpy as np
import os
# 旋转图像
def rotate_image(image, angle):
# 获取图像的高度和宽度
(h, w) = image.shape[:2]
# 找到图像的中心点
center = (w / 2, h / 2)
# 通过旋转矩阵进行旋转
M = cv2.getRotationMatrix2D(center, angle, 1.0)
rotated = cv2.warpAffine(image, M, (w, h))
# 返回旋转后的图像
return rotated
# 裁剪图像
def crop_image(image, x, y, w, h):
cropped = image[y:y+h, x:x+w]
return cropped
# 扩充数据集
def augment_data(image_path, label_path, output_path):
# 读取图像和标签数据
image = cv2.imread(image_path)
with open(label_path, 'r') as f:
labels = f.readlines()
# 对图像进行旋转和裁剪操作
for i in range(5):
# 旋转图像
angle = np.random.randint(-10, 10)
rotated = rotate_image(image, angle)
# 裁剪图像
x = np.random.randint(0, 100)
y = np.random.randint(0, 100)
w = np.random.randint(100, 300)
h = np.random.randint(100, 300)
cropped = crop_image(rotated, x, y, w, h)
# 保存扩充后的图像和标签
filename = os.path.basename(image_path)
image_name, ext = os.path.splitext(filename)
output_image_path = os.path.join(output_path, f"{image_name}_{i}{ext}")
output_label_path = os.path.join(output_path, f"{image_name}_{i}.txt")
cv2.imwrite(output_image_path, cropped)
with open(output_label_path, 'w') as f:
for label in labels:
f.write(label)
```
这个代码实现了对输入图像的旋转和裁剪操作,以扩充数据集。对于每个输入图像,将生成5个扩充后的图像,并将它们保存到指定的输出目录中。这个代码可以在YoloV5的训练过程中使用,以提高模型的准确性。
阅读全文