使用pytorch设计构建图像识别网络,它的训练数据包含15种,每种25个,合计3750个样本,图像大小为28*28的灰度图像。以下是部分参考代码import numpy as np import cv2 import os dirPath = "./testDataOrg/" totNum=0 if os.path.isdir(dirPath): fileList = os.listdir(dirPath) for f in fileList: nn1 = f.rfind(',')+1 nn2 = f.find('}',nn1) label = int(f[nn1:nn2])#图像标签 1到11对应”零“到”十“,12到15对应”百“,”千“,”万“,”亿“。 print(f, label) totNum += 1 fi = dirPath+f img = cv2.imread(fi)#图像数据 cv2.imshow("img", cv2.resize(img, (300, 300))) cv2.waitKey(0) print("图像数量:",totNum)
时间: 2024-03-30 21:36:54 浏览: 42
将使用Pytorch框架实现经典的图像分类网络、目标检测网络、图像分割网络,图像生成网络等,并会持续更新!!!.zip
这段代码是用来读取图像数据和标签的。首先,它将指定目录下的所有文件名读入到fileList中。接着,对于每个文件名f,通过查找字符串中的特定字符来获取其标签label。图像的标签是从1到15的整数,其中1到11表示数字0到9,12到15表示“百”,“千”,“万”和“亿”。然后,使用OpenCV库的imread函数读取图像数据。最后,通过imshow函数将图像显示出来,并等待用户按下任意键关闭图像窗口。最后,打印出总共读入的图像数量。这段代码只是读取数据的代码,还需要进一步设计模型和进行训练才能实现图像识别的功能。
阅读全文