knn算法java实现鸢尾花
时间: 2023-11-18 08:51:45 浏览: 129
好的,对于knn算法的Java实现,我们可以按照以下步骤进行:
1. 读取数据集
2. 对数据集进行预处理,例如数据归一化、特征选择等
3. 将数据集划分为训练集和测试集
4. 针对测试集中的每个样本,计算其与训练集中所有样本的距离
5. 对距离计算结果进行排序,选择前k个距离最近的样本
6. 根据这k个样本的类别,使用投票法确定测试样本所属的类别
7. 计算模型的准确率
相关代码实现如下:
```
// 1. 读取数据集
List<List<Double>> dataSet = new ArrayList<>();
List<Integer> labels = new ArrayList<>();
try {
BufferedReader reader = new BufferedReader(new FileReader("iris.data")); // 数据集文件路径
String line;
while ((line = reader.readLine()) != null) {
String[] fields = line.split(",");
List<Double> dataRow = new ArrayList<>();
for (int i = 0; i < fields.length - 1; i++) {
dataRow.add(Double.parseDouble(fields[i]));
}
dataSet.add(dataRow);
labels.add(Integer.parseInt(fields[fields.length - 1]));
}
reader.close();
} catch (IOException e) {
e.printStackTrace();
}
// 2. 数据预处理(省略)
// 3. 将数据集划分为训练集和测试集
int n = dataSet.size();
int m = n * 7 / 10; // 训练集占比70%
List<List<Double>> trainData = new ArrayList<>();
List<Integer> trainLabels = new ArrayList<>();
List<List<Double>> testData = new ArrayList<>();
List<Integer> testLabels = new ArrayList<>();
List<Integer> indexList = new ArrayList<>();
for (int i = 0; i < n; i++) {
indexList.add(i);
}
Collections.shuffle(indexList); // 打乱数据集顺序
for (int i = 0; i < m; i++) {
trainData.add(dataSet.get(indexList.get(i)));
trainLabels.add(labels.get(indexList.get(i)));
}
for (int i = m; i < n; i++) {
testData.add(dataSet.get(indexList.get(i)));
testLabels.add(labels.get(indexList.get(i)));
}
// 4. 计算距离和排序
int k = 5; // k值
int errorCount = 0;
for (int i = 0; i < testData.size(); i++) {
List<Double> testRow = testData.get(i);
PriorityQueue<Pair<Double, Integer>> pq = new PriorityQueue<>((a, b) -> -Double.compare(a.getKey(), b.getKey())); // 大根堆
for (int j = 0; j < trainData.size(); j++) {
List<Double> trainRow = trainData.get(j);
double dist = 0;
for (int c = 0; c < testRow.size(); c++) {
dist += Math.pow(testRow.get(c) - trainRow.get(c), 2);
}
pq.offer(new Pair<>(dist, trainLabels.get(j)));
if (pq.size() > k) {
pq.poll();
}
}
// 6. 投票法
int[] count = new int[3];
Arrays.fill(count, 0);
for (Pair<Double, Integer> pair : pq) {
count[pair.getValue() - 1]++;
}
int predict = 1;
for (int j = 1; j < 3; j++) {
if (count[j] > count[predict - 1]) {
predict = j + 1;
}
}
if (predict != testLabels.get(i)) {
errorCount++;
}
}
// 7. 计算准确率
double accuracy = (double) (testData.size() - errorCount) / testData.size();
System.out.println("Accuracy: " + accuracy);
```
其中,我们假设数据集的类别只有3种,分别为1、2、3。
阅读全文