Write a Python script that imitates the Mosaic data augmentation used in YOLOX
Time: 2023-12-05 15:03:04
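Sure. Below is a simple Python script that implements Mosaic data augmentation. It is a simplified, grid-based variant: the input images are tiled into a fixed grid of cells and their bounding boxes are remapped to match, rather than being placed around a random center point as in YOLOX.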
```python
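import cv2
import numpy as np


def mosaic(images, labels, img_size, grid_size):
    """
    Grid-based Mosaic data augmentation: tile several images into one large image
    and remap their bounding-box labels accordingly.
    Args:
    - images: list of images; each element is an image path or a BGR numpy array (H x W x 3).
    - labels: list of labels, one entry per image, in the format [[x1, y1, x2, y2, class_id], ...],
      with coordinates normalized to [0, 1] relative to that image.
    - img_size: size of the output image, as (width, height).
    - grid_size: size of the tiling grid, as (cols, rows).
    Returns:
    - output_image: the tiled image, a numpy array of shape (height, width, 3).
    - output_labels: remapped labels [[x1, y1, x2, y2, class_id], ...] in integer
      pixel coordinates of output_image.
    """
    assert len(images) == len(labels), "The number of images and labels must be the same."
    assert grid_size[0] * grid_size[1] == len(images), "The number of images must be equal to the number of grid cells."
    # Size of one grid cell, as (width, height)
    cell_size = (img_size[0] // grid_size[0], img_size[1] // grid_size[1])
    # Load each image and resize it to the cell size (not the full output size),
    # so that it fits exactly into its cell
    input_images = []
    for image in images:
        if isinstance(image, str):
            path = image
            image = cv2.imread(path)
            if image is None:
                raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.resize(image, cell_size)
        input_images.append(image)
    # Start from a gray canvas; it only stays visible if img_size is not
    # divisible by grid_size
    output_image = np.full((img_size[1], img_size[0], 3), 114, dtype=np.uint8)
    output_labels = []
    # Walk over the grid cells and paste the corresponding image into each one
    for row in range(grid_size[1]):
        for col in range(grid_size[0]):
            index = row * grid_size[0] + col
            image = input_images[index]
            # Pixel bounds of this cell in the output image
            x1 = col * cell_size[0]
            y1 = row * cell_size[1]
            x2 = x1 + cell_size[0]
            y2 = y1 + cell_size[1]
            # Paste the current image into its cell
            output_image[y1:y2, x1:x2, :] = image
            # Remap labels from normalized per-image coordinates to output pixels
            for x1_, y1_, x2_, y2_, class_id in labels[index]:
                # Clip to [0, 1] so a box cannot spill into a neighbouring cell
                x1_, y1_, x2_, y2_ = (min(max(v, 0.0), 1.0) for v in (x1_, y1_, x2_, y2_))
                if x2_ <= x1_ or y2_ <= y1_:
                    continue  # the box is empty after clipping
                output_labels.append([
                    int(round(x1_ * cell_size[0] + x1)),
                    int(round(y1_ * cell_size[1] + y1)),
                    int(round(x2_ * cell_size[0] + x1)),
                    int(round(y2_ * cell_size[1] + y1)),
                    class_id,
                ])
    return output_image, output_labels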
```
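The box remapping inside the loop above is a per-cell scale and shift. Using notation introduced here (not in the code), let $(W, H)$ be `img_size`, $(C, R)$ be `grid_size`, $(c, r)$ the column and row of the cell, and $(x_n, y_n)$ a normalized box corner; the corner lands at

$$
w_c = \left\lfloor \frac{W}{C} \right\rfloor,\quad h_c = \left\lfloor \frac{H}{R} \right\rfloor,\qquad
x = c\,w_c + x_n\,w_c,\quad y = r\,h_c + y_n\,h_c
$$

and is then converted to an integer pixel. For example, with `img_size=(800, 800)` and `grid_size=(2, 2)` each cell is 400×400, so the box `[0.4, 0.5, 0.6, 0.7]` in the bottom-right cell ($c = r = 1$) becomes `[560, 600, 640, 680]`.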
Usage:
```python
# Image paths and their labels: normalized [x1, y1, x2, y2, class_id], values in [0, 1]
images = ["image1.jpg", "image2.jpg", "image3.jpg", "image4.jpg"]
labels = [
    [[0.1, 0.2, 0.3, 0.4, 0], [0.5, 0.6, 0.7, 0.8, 1]],
    [[0.2, 0.3, 0.4, 0.5, 1]],
    [[0.3, 0.4, 0.5, 0.6, 2]],
    [[0.4, 0.5, 0.6, 0.7, 0], [0.8, 0.9, 1.0, 1.0, 2]],
]
# Build an 800x800 mosaic from a 2x2 grid (each cell is 400x400)
output_image, output_labels = mosaic(images, labels, (800, 800), (2, 2))
# Draw the remapped boxes and display the result
for x1, y1, x2, y2, class_id in output_labels:
    cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.imshow("output image", output_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
```
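For comparison, YOLOX's own mosaic (its `MosaicDetection` dataset wrapper) does not use a fixed grid: it picks a random center point on a gray (114) canvas twice the target size, places four images around that center, one per quadrant, and then applies a random affine transform that brings the result back to the target size. The sketch below illustrates only the random-center placement, under stated assumptions: it is not YOLOX's code, the names `random_center_mosaic` and `resize_nearest` are hypothetical, a NumPy nearest-neighbour resize stands in for `cv2.resize` so it runs without OpenCV, the affine step is omitted, and labels use the same normalized format as `mosaic()` above.

```python
import random

import numpy as np


def resize_nearest(image, new_w, new_h):
    """Nearest-neighbour resize using NumPy only (stand-in for cv2.resize)."""
    h, w = image.shape[:2]
    rows = np.arange(new_h) * h // new_h  # source row for each output row
    cols = np.arange(new_w) * w // new_w  # source column for each output column
    return image[rows][:, cols]


def random_center_mosaic(images, labels, input_size, fill=114, rng=random):
    """Hypothetical YOLOX-style mosaic sketch: 4 images around a random center.

    images: four H x W x 3 uint8 arrays.
    labels: four lists of [x1, y1, x2, y2, class_id], normalized to [0, 1].
    input_size: target (width, height); the canvas is twice that size.
    Returns the canvas and boxes in canvas pixels, clipped to the canvas.
    """
    assert len(images) == 4 and len(labels) == 4, "expects exactly 4 images"
    w, h = input_size
    canvas = np.full((2 * h, 2 * w, 3), fill, dtype=np.uint8)
    # Random center, restricted to the middle half of the canvas
    xc = int(rng.uniform(0.5 * w, 1.5 * w))
    yc = int(rng.uniform(0.5 * h, 1.5 * h))
    out_labels = []
    for i, (img, boxes) in enumerate(zip(images, labels)):
        # Aspect-preserving resize so the image fits inside (w, h)
        scale = min(w / img.shape[1], h / img.shape[0])
        nw = max(1, int(img.shape[1] * scale))
        nh = max(1, int(img.shape[0] * scale))
        img = resize_nearest(img, nw, nh)
        # Top-left corner of the resized image on the canvas:
        # 0 = above-left of center, 1 = above-right, 2 = below-left, 3 = below-right
        padw = xc - nw if i in (0, 2) else xc
        padh = yc - nh if i in (0, 1) else yc
        # Visible part of the image after clipping to the canvas
        x1, y1 = max(padw, 0), max(padh, 0)
        x2, y2 = min(padw + nw, 2 * w), min(padh + nh, 2 * h)
        canvas[y1:y2, x1:x2] = img[y1 - padh:y2 - padh, x1 - padw:x2 - padw]
        # Shift boxes into canvas pixels, clip, and drop boxes that become empty
        for bx1, by1, bx2, by2, cls in boxes:
            bx1 = min(max(bx1 * nw + padw, 0), 2 * w)
            by1 = min(max(by1 * nh + padh, 0), 2 * h)
            bx2 = min(max(bx2 * nw + padw, 0), 2 * w)
            by2 = min(max(by2 * nh + padh, 0), 2 * h)
            if bx2 > bx1 and by2 > by1:
                out_labels.append([bx1, by1, bx2, by2, cls])
    return canvas, out_labels
```

For instance, `random_center_mosaic(four_images, four_labels, (640, 640), rng=random.Random(0))` returns a 1280×1280 canvas; passing a seeded `random.Random` makes the center, and hence the layout, reproducible.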