Write a nearest neighbor algorithm in Java
Time: 2024-01-23 21:03:48
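The nearest neighbor algorithm, implemented here in its k-nearest neighbor (kNN) form, is a simple classification algorithm. For a given sample, it finds the k training samples most similar to it, meaning the ones closest to it in feature space. If most of those k samples belong to one class, the sample is assigned to that class too.
Java implementation: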
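The same rule can be written as formulas. This assumes Euclidean distance, which is what `calcDistance` in the implementation computes. $N_k(\mathbf{x})$ is notation introduced here for the k training points closest to the query $\mathbf{x}$, and $c_j$ is the class of training point $j$:

$$
d(\mathbf{a}, \mathbf{b}) = \sqrt{\sum_{i=1}^{n} (a_i - b_i)^2},
\qquad
\hat{c}(\mathbf{x}) = \arg\max_{c} \sum_{j \in N_k(\mathbf{x})} \mathbb{1}[c_j = c]
$$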
```java
import java.util.*;

public class KNN {
    // Euclidean distance between two points of the same dimension
    private double calcDistance(double[] a, double[] b) {
        double distance = 0;
        for (int i = 0; i < a.length; i++) {
            distance += Math.pow(a[i] - b[i], 2);
        }
        return Math.sqrt(distance);
    }

    // Majority vote over the classes of the selected neighbors.
    // Loops over the neighbors actually collected, which can be fewer than k
    // when the training set is small (looping to k would throw an exception).
    // Ties are broken by HashMap iteration order, which is unspecified.
    private int getMajorityClass(List<Integer> classes) {
        Map<Integer, Integer> classCount = new HashMap<>();
        for (int c : classes) {
            classCount.put(c, classCount.getOrDefault(c, 0) + 1);
        }
        int maxCount = 0;
        int majorityClass = -1;
        for (Map.Entry<Integer, Integer> entry : classCount.entrySet()) {
            int count = entry.getValue();
            if (count > maxCount) {
                maxCount = count;
                majorityClass = entry.getKey();
            }
        }
        return majorityClass;
    }

    // k-nearest neighbor classification of testPoint
    public int knn(int k, double[] testPoint, List<double[]> trainData, List<Integer> trainClasses) {
        if (k <= 0 || trainData.size() != trainClasses.size()) {
            throw new IllegalArgumentException("k must be positive and every training point needs a class");
        }
        // Distance to testPoint and class of each of the (at most k) current neighbors
        List<Double> distances = new ArrayList<>();
        List<Integer> classes = new ArrayList<>();
        for (int i = 0; i < trainData.size(); i++) {
            double distance = calcDistance(trainData.get(i), testPoint);
            if (distances.size() < k) {
                distances.add(distance);
                classes.add(trainClasses.get(i));
            } else {
                // Find the farthest of the current neighbors (distances are cached, not recomputed)
                int maxIndex = 0;
                for (int j = 1; j < distances.size(); j++) {
                    if (distances.get(j) > distances.get(maxIndex)) {
                        maxIndex = j;
                    }
                }
                // Replace it if the new training point is strictly closer
                if (distance < distances.get(maxIndex)) {
                    distances.set(maxIndex, distance);
                    classes.set(maxIndex, trainClasses.get(i));
                }
            }
        }
        return getMajorityClass(classes);
    }
}
```
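Each time the implementation above sees a new training point, it scans its k current neighbors for the farthest one. That costs O(k) per point. A common alternative keeps the k candidates in a max-heap keyed on distance, so the farthest one is always at the head and each update costs O(log k). The sketch below assumes `java.util.PriorityQueue` is acceptable. `HeapKNN`, `classify` and `Neighbor` are hypothetical names, not part of the original code. Resolving vote ties toward the smallest label, through a `TreeMap`, is a choice made only for this sketch.

```java
import java.util.*;

// Sketch: k-NN classification with a bounded max-heap (hypothetical class HeapKNN).
public class HeapKNN {
    // A candidate neighbor: its distance to the query and its class label
    private static final class Neighbor {
        final double distance;
        final int label;

        Neighbor(double distance, int label) {
            this.distance = distance;
            this.label = label;
        }
    }

    // Euclidean distance, same formula as KNN.calcDistance
    private static double distance(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    public static int classify(int k, double[] query, List<double[]> trainData, List<Integer> trainClasses) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive");
        }
        // Max-heap by distance: peek() returns the farthest of the kept neighbors
        PriorityQueue<Neighbor> heap = new PriorityQueue<>((x, y) -> Double.compare(y.distance, x.distance));
        for (int i = 0; i < trainData.size(); i++) {
            double d = distance(trainData.get(i), query);
            if (heap.size() < k) {
                heap.add(new Neighbor(d, trainClasses.get(i)));
            } else if (d < heap.peek().distance) {
                heap.poll(); // evict the current farthest neighbor
                heap.add(new Neighbor(d, trainClasses.get(i)));
            }
        }
        // Majority vote; TreeMap iterates labels in ascending order, so ties go to the smallest label
        Map<Integer, Integer> votes = new TreeMap<>();
        for (Neighbor n : heap) {
            votes.merge(n.label, 1, Integer::sum);
        }
        int best = -1;
        int bestCount = 0;
        for (Map.Entry<Integer, Integer> e : votes.entrySet()) {
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best; // -1 if trainData is empty
    }
}
```

Over n training points the heap brings the selection cost down from O(n·k) to O(n·log k). For the small k values typical of kNN, both approaches are fast.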
Usage example:
```java
import java.util.*;

public class KNNTest {
    public static void main(String[] args) {
        KNN knn = new KNN();
        // Training data: six 2-D points
        List<double[]> trainData = new ArrayList<>();
        trainData.add(new double[]{1, 1});
        trainData.add(new double[]{1, 2});
        trainData.add(new double[]{2, 2});
        trainData.add(new double[]{2, 3});
        trainData.add(new double[]{3, 3});
        trainData.add(new double[]{3, 4});
        // Class of each training point, in the same order as trainData
        List<Integer> trainClasses = new ArrayList<>();
        trainClasses.add(0);
        trainClasses.add(0);
        trainClasses.add(0);
        trainClasses.add(1);
        trainClasses.add(1);
        trainClasses.add(1);
        // Point to classify
        double[] testPoint = new double[]{2.5, 3};
        // Use k = 3 neighbors
        int k = 3;
        int result = knn.knn(k, testPoint, trainData, trainClasses);
        System.out.println(result);
    }
}
```
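The output is 1, which means the test point belongs to class 1. Its two nearest training points, (2, 3) and (3, 3), are both at distance 0.5 and both belong to class 1, so class 1 wins at least two of the three votes.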
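The result can be double-checked by brute force. The sketch below sorts all six training points by distance to the test point and votes among the first k. `KNNCheck`, `dist` and `nearestIndices` are hypothetical names, not part of the original code. The vote array also assumes that the labels are 0 or 1, as they are in this example.

```java
import java.util.*;

// Brute-force check of the KNNTest example (hypothetical class KNNCheck)
public class KNNCheck {
    // Euclidean distance, same formula as KNN.calcDistance
    static double dist(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    // Indices of the k training points closest to test; List.sort is stable,
    // so points at equal distance keep their original order
    static List<Integer> nearestIndices(double[][] train, double[] test, int k) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < train.length; i++) {
            order.add(i);
        }
        order.sort(Comparator.comparingDouble(i -> dist(train[i], test)));
        return order.subList(0, Math.min(k, order.size()));
    }

    public static void main(String[] args) {
        double[][] train = {{1, 1}, {1, 2}, {2, 2}, {2, 3}, {3, 3}, {3, 4}};
        int[] labels = {0, 0, 0, 1, 1, 1};
        double[] test = {2.5, 3};
        int[] votes = new int[2]; // assumes labels are 0 or 1, as in this example
        for (int i : nearestIndices(train, test, 3)) {
            votes[labels[i]]++;
            System.out.printf(Locale.ROOT, "(%.0f, %.0f) class %d distance %.3f%n",
                    train[i][0], train[i][1], labels[i], dist(train[i], test));
        }
        System.out.println("predicted class: " + (votes[1] > votes[0] ? 1 : 0));
        // prints:
        // (2, 3) class 1 distance 0.500
        // (3, 3) class 1 distance 0.500
        // (2, 2) class 0 distance 1.118
        // predicted class: 1
    }
}
```

(2, 2) and (3, 4) are tied at distance about 1.118 for third place. The stable sort keeps the earlier one, (2, 2), but either choice leaves class 1 with the majority.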