使用使用Keras 的的ImageDataGenerator类实现批量数据增强类实现批量数据增强
今天使用了Keras 的ImageDataGenerator类,发现真是小白的神器。我们在进行机器学习的时候,常常为自己找不到相关的数据集而自己搭建一个数据集。那么,相关的问题就
是数据样本不够大,之后的机器学习就很有能造成过拟合问题,神经网络找不到抽象的特征等。究根结底还是样本数量不够。那我又不可能拿着相机一个一个去拍啊…….
总之就想要更多的数据集呗。
有关于ImageDataGenerator的相关信息,这篇博客已经写得非常好了–>keras的图像预处理全攻略(二)—— ImageDataGenerator 类, 有关于ImageDataGenerator类的用法什
么的可以查它
下面是全部代码(Win10 python3.6.1 pycharm)
import os
import shutil
from keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array, array_to_img
class Augmentation(object):
def __init__(self,img_type="png"):
self.datagen=ImageDataGenerator(
rotation_range=10,
width_shift_range=0.05,
height_shift_range=0.05,
shear_range=0.05,
zoom_range=0.1,
horizontal_flip=True,
fill_mode='nearest')
def augmentation(self, picname, savedir, pre, img_num):
print("运行 Augmentation")
# Start augmentation.....
img_t = load_img(picname) # 读入train
x_t = img_to_array(img_t) # 转换成矩阵
img = x_t
img = img.reshape((1,) + img.shape)
print(img.shape)
if not os.path.lexists(savedir):
os.mkdir(savedir)
print("running %d doAugmenttaion" % 0)
self.do_augmentate(img, savedir, pre, imgnum=img_num) # 数据增强
def do_augmentate(self, img, save_to_dir, save_prefix, batch_size=1, save_format='jpg', imgnum=10):
# augmentate one image
datagen = self.datagen
i = 0
for _ in datagen.flow(
img,
batch_size=batch_size,
save_to_dir=save_to_dir,
save_prefix=save_prefix,
save_format=save_format):
i += 1
if i >= imgnum:
break
def batchgenerate(self, picsrc, savedir, imgnum):
picList = os.listdir(picsrc)
len1 = len(picList)
subfile = "SUB"
if not os.path.lexists("SUB"):
os.mkdir("SUB")
for picname in picList:
file_src = picsrc + '/' + picname
pre = picname.split('.')[0] self.augmentation(file_src, "SUB", pre, imgnum)
picListsub = os.listdir("SUB")
num = [] for i in range(len1):
num.append(0)
print("图片个数", len(num))
loc = 0
print("转换图片列表",picList)
for picname in picListsub:
file_src = "SUB" + '/' + picname
pre = picname.split('_')[0] + '_' + picname.split('_')[1]+'.jpg'
location = picList.index(pre)
all = picname.split('_')[0] + '_' + picname.split('_')[1]+"_"+str(num[location])+'.jpg'
num[location] = num[location]+1
file_dst = savedir + '/' + all
shutil.copyfile(file_src, file_dst)
os.remove(file_src)
if __name__=="__main__":
picsrc = "C:/Users/ddd/Desktop/project/dataSets/counterpart/aeroplane_0"
savedir = "imageExpand"
aug=Augmentation()
aug.batchgenerate(picsrc, savedir, 5)
第一段:
def __init__(self,img_type="png"):
self.datagen=ImageDataGenerator(
rotation_range=10, # 图片可能在(-10,10)度内旋转
width_shift_range=0.05, # 图片可能在左右比例(-0.05,0.05)(百分比)内水平移动
height_shift_range=0.05, # 图片可能在上下比例(-0.05,0.05)(百分比)内水平移动
shear_range=0.05, # 图片可能以比例(-0.05,0.05)内错切变换
zoom_range=0.1, # 图片可能以比例(-0.1,0.1)内缩放
评论0