def adjustData(img,mask,flag_multi_class,num_class): if(flag_multi_class): img = img / 255 mask = mask[:,:,:,0] if(len(mask.shape) == 4) else mask[:,:,0] new_mask = np.zeros(mask.shape + (num_class,)) for i in range(num_class): #for one pixel in the image, find the class in mask and convert it into one-hot vector #index = np.where(mask == i) #index_mask = (index[0],index[1],index[2],np.zeros(len(index[0]),dtype = np.int64) + i) if (len(mask.shape) == 4) else (index[0],index[1],np.zeros(len(index[0]),dtype = np.int64) + i) #new_mask[index_mask] = 1 new_mask[mask == i,i] = 1 new_mask = np.reshape(new_mask,(new_mask.shape[0],new_mask.shape[1]*new_mask.shape[2],new_mask.shape[3])) if flag_multi_class else np.reshape(new_mask,(new_mask.shape[0]*new_mask.shape[1],new_mask.shape[2])) mask = new_mask elif(np.max(img) > 1): img = img / 255 mask = mask /255 mask[mask > 0.5] = 1 mask[mask <= 0.5] = 0 return (img,mask)这个函数处理的是几维图片
时间: 2024-01-26 12:03:41 浏览: 105
Python中if __name__ == '__main__'作用解析
这个函数处理的是二维或三维图片,取决于输入的img和mask的维度。如果img和mask是二维的,则代表是灰度图像;如果是三维的,则代表是彩色图像。函数中的代码可以将多分类的标签数据转换成适合训练的形式。如果flag_multi_class参数为True,则将mask转换成一个新的张量,该张量的最后一维是类别数;否则,将mask压缩成一维张量。
阅读全文