再具体一点能针对某一类别进行数据增强吗
时间: 2023-10-02 22:08:29 浏览: 66
当需要针对特定类别进行数据增强时,你可以在代码中加入条件判断来控制增强的行为。下面是一个示例代码,针对标注中的某一特定类别进行数据增强:
```python
import os
import cv2
import numpy as np
import imgaug.augmenters as iaa
from labelme import utils
def data_augmentation(image_dir, label_dir, save_dir, target_class):
# 创建保存增强后数据的文件夹
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# 获取所有图片文件
image_files = os.listdir(image_dir)
for image_file in image_files:
# 读取原始图片
image_path = os.path.join(image_dir, image_file)
image = cv2.imread(image_path)
# 读取对应的标注文件
label_file = os.path.splitext(image_file)[0] + ".json"
label_path = os.path.join(label_dir, label_file)
with open(label_path, "r") as f:
label_data = utils.json_load(f)
label_shapes = label_data["shapes"]
# 数据增强
seq = iaa.Sequential([
iaa.Fliplr(0.5), # 水平翻转概率为50%
iaa.GaussianBlur(sigma=(0, 3.0)), # 高斯模糊
iaa.Affine(rotate=(-10, 10)), # 旋转角度范围为-10到10度
iaa.Multiply((0.8, 1.2)) # 亮度乘法器
])
augmented_image = seq.augment_image(image)
augmented_label_shapes = []
for shape in label_shapes:
class_name = shape["label"]
if class_name == target_class:
# 针对目标类别进行增强
augmented_shape = seq.augment_keypoints([ia.Keypoint(x=point[0], y=point[1]) for point in shape["points"]])
augmented_label_shapes.append({"label": class_name, "points": [[kp.x, kp.y] for kp in augmented_shape.keypoints]})
else:
# 其他类别保持不变
augmented_label_shapes.append(shape)
# 保存增强后的图片和标注
augmented_image_path = os.path.join(save_dir, image_file)
cv2.imwrite(augmented_image_path, augmented_image)
augmented_label_data = label_data.copy()
augmented_label_data["shapes"] = augmented_label_shapes
augmented_label_file = os.path.join(save_dir, label_file)
with open(augmented_label_file, "w") as f:
utils.json_dump(augmented_label_data, f)
print("Saved augmented data:", augmented_image_path, augmented_label_file)
# 示例调用
image_dir = "path_to_original_images"
label_dir = "path_to_label_files"
save_dir = "path_to_save_augmented_data"
target_class = "person" # 指定目标类别
data_augmentation(image_dir, label_dir, save_dir, target_class)
```
在示例代码中,我们通过 `target_class` 参数指定了目标类别。对于符合目标类别的标注形状,我们使用 `iaa.augment_keypoints` 函数对关键点进行增强,对其他类别的标注形状保持不变。
请注意,`target_class` 参数需要与标注文件中的类别名称完全匹配。如果你的标注文件中的类别名称是中文或其他非英文字符,请确保在代码中使用正确的字符编码。
希望这个示例能够满足你的需求!如果有任何问题,请随时提问。