a=np.random.randint(90,96,size=(20,2)).astype(np.float32) b=np.random.randint(95,101,size=(20,2)).astype(np.float32) data=np.vstack((a,b)) data=np.array(data,dtype="float32") alabel=np.zeros((20,1)) blabel=np.ones((20,1)) label=np.vstack((alabel,blabel)) label=np.array(label,dtype="float32") svm=cv.ml.SVM_create() svm.setType(cv.ml.SVM_C_SVC) svm.setKernel(cv.ml.SVM_LINEAR) svm.setC(0.01) result=svm.train(data,cv.ml.ROW_SAMPLE,label) test=np.vstack(([93.92],[96.97])) test=np.array(test,dtype=np.float32) p1,p2=svm.predict(test) plt.scatter(a[:,0],a[:,1],80,"g","o") plt.scatter(b[:,0],b[:,1],80,"b","s") plt.scatter(test[:,0],test[:,1],80,"r","*") plt.show() print(p1,p2)代码哪里需要修改
时间: 2023-12-20 14:05:52 浏览: 185
feature_extraction.py
这个代码看起来没有明显的错误,但是有一些地方可以进行修改和改进。
1. 在定义变量时,建议使用小写字母,按照PEP8规范进行命名,例如将“alabel”和“blabel”改为“a_label”和“b_label”。
2. 在定义测试数据时,可以使用np.array直接创建数组,而无需使用vstack和array组合的方式。
3. 在使用plt.scatter函数绘制散点图时,第三个参数应该是标记的大小,而不是颜色。
4. 在使用svm.predict函数时,应该将测试数据作为二维数组传递,例如:svm.predict(np.array([[93.92, 96.97]], dtype=np.float32))。
下面是修改后的代码:
```python
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
a = np.random.randint(90,96,size=(20,2)).astype(np.float32)
b = np.random.randint(95,101,size=(20,2)).astype(np.float32)
data = np.vstack((a,b))
data = np.array(data,dtype=np.float32)
a_label = np.zeros((20,1))
b_label = np.ones((20,1))
label = np.vstack((a_label,b_label))
label = np.array(label,dtype=np.float32)
svm = cv.ml.SVM_create()
svm.setType(cv.ml.SVM_C_SVC)
svm.setKernel(cv.ml.SVM_LINEAR)
svm.setC(0.01)
result = svm.train(data,cv.ml.ROW_SAMPLE,label)
test = np.array([[93.92, 96.97]], dtype=np.float32)
p1, p2 = svm.predict(test)
plt.scatter(a[:,0],a[:,1],s=80,c="g",marker="o")
plt.scatter(b[:,0],b[:,1],s=80,c="b",marker="s")
plt.scatter(test[:,0],test[:,1],s=80,c="r",marker="*")
plt.show()
print(p1, p2)
```
希望这些修改可以帮助您更好地运行代码!
阅读全文