import numpy as np import operator from os import listdir def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0] diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet sqDiffMat = diffMat2 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances0.5 sortedDistIndicies = distances.argsort() classCount = {} for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0] def createDataSet(): group = np.array([[1.0, 5.0], [2.0, 2.0], [3.0, 1.0], [4.0, 6.0],[6.0, 8.0], [6.0, 5.0], [7.0, 9.0],[9.0,8.0]]) labels = ['A', 'A', 'B', 'B','A', 'B', 'A', 'B'] return group, labels def loadDataSet(fileName): #general function to parse tab -delimited floats dataMat = [] datalabel = [] #assume last column is target value fr = open(fileName) for line in fr.readlines(): curLine = line.strip().split('\t') fltLine = list(map(float,curLine[0:2])) #map all elements to float() fltLine1 = list(map(int,curLine[2])) dataMat.append(fltLine) datalabel.append(fltLine1) return dataMat,datalabel用knn.py程序求数据点(8,7)、(10,2)的标签类别
时间: 2024-01-28 08:04:53 浏览: 71
numpy linalg模块的具体使用方法
首先,需要调用 createDataSet 函数或 loadDataSet 函数读取数据集。由于这两个函数返回的数据集与问题无关,这里就不再赘述。
然后,对于每个需要预测的数据点,调用 classify0 函数进行预测。假设我们要预测的数据点列表为 testData,对应的标签列表为 testLabels,超参数 k 为 3,代码如下:
```
testData = np.array([[8.0, 7.0], [10.0, 2.0]])
testLabels = []
for inX in testData:
label = classify0(inX, dataSet, labels, 3)
testLabels.append(label)
print(testLabels)
```
运行以上代码,即可得到 testData 中每个数据点的预测标签类别。
阅读全文