基于opencv的水果识别的代码
时间: 2023-10-11 22:05:31 浏览: 235
以下是一个基于OpenCV的简单水果识别代码示例,使用了SVM算法:
```python
import cv2
import numpy as np
import matplotlib.pyplot as plt
# 加载数据集
dataset = cv2.imread("fruits.jpg")
labels = np.array(["apple", "banana", "orange", "apple", "banana", "orange"])
# 特征提取
def extract_features(image):
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
blur = cv2.GaussianBlur(gray, (5, 5), 0)
_, thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
cnts, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cnt = cnts[0]
x, y, w, h = cv2.boundingRect(cnt)
roi = image[y:y + h, x:x + w]
roi = cv2.resize(roi, (50, 50))
return roi.flatten()
# 特征提取和标签编码
data = []
for i in range(len(dataset)):
features = extract_features(dataset[i])
data.append(features)
data = np.array(data)
label_dict = {"apple": 0, "banana": 1, "orange": 2}
labels = np.array([label_dict[i] for i in labels])
# 模型训练
model = cv2.ml.SVM_create()
model.setType(cv2.ml.SVM_C_SVC)
model.setKernel(cv2.ml.SVM_LINEAR)
model.setTermCriteria((cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_COUNT, 10000, 1e-6))
model.train(data, cv2.ml.ROW_SAMPLE, labels)
# 预测图像
test_image = cv2.imread("test_fruit.jpg")
test_features = extract_features(test_image)
test_data = np.array(test_features).reshape(1, -1)
_, result = model.predict(test_data)
print("预测结果为:", list(label_dict.keys())[list(label_dict.values()).index(int(result))])
# 显示图像
plt.imshow(cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB))
plt.show()
```
需要注意的是,这只是一个简单的示例,实际的水果识别系统需要更多的数据和更复杂的算法以提高准确性。同时,这里的特征提取方法也可进行优化。
阅读全文