java代码实现Kmeans聚类算法,训练集的文件类型为txt,java所有代码都写在k3.java中
时间: 2024-05-14 12:16:56 浏览: 211
好的,让我来回答您的问题。
首先,我要先说明一下,Kmeans聚类算法是一种无监督学习的算法,用于将样本数据分成多个类别,每个类别内部的数据相似度较高,在相应的维度上距离较近。以下是一个简单的Java代码实现,可以帮助您实现Kmeans聚类算法。
```
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class k3 {
static class Point {
private double x;
private double y;
public Point(double x, double y) {
this.x = x;
this.y = y;
}
public double getX() {
return x;
}
public void setX(double x) {
this.x = x;
}
public double getY() {
return y;
}
public void setY(double y) {
this.y = y;
}
public String toString() {
return "(" + x + "," + y + ")";
}
}
public static void main(String[] args) {
List<Point> dataSet = readDataSet("trainSet.txt");
List<Point> centroids = initCentroids(dataSet, 3); //分成三类
// 初始化聚类结果
int[] clusterAssment = new int[dataSet.size()];
boolean clusterChanged = true;
while (clusterChanged) {
clusterChanged = false;
//对数据集中每一个点做聚类
for (int i = 0; i < dataSet.size(); i++) {
int minIndex = -1;
double minDist = Double.MAX_VALUE;
//找出距离数据点最近的中心点
for (int j = 0; j < centroids.size(); j++) {
double distance = calcDistance(dataSet.get(i), centroids.get(j));
if (distance < minDist) {
minDist = distance;
minIndex = j;
}
}
//如果该数据点所属的中心点变化了,需要重新进行聚类
if (clusterAssment[i] != minIndex) {
clusterChanged = true;
clusterAssment[i] = minIndex;
}
}
//重新计算中心点
for (int j = 0; j < centroids.size(); j++) {
double sumX = 0;
double sumY = 0;
double count = 0;
for (int k = 0; k < dataSet.size(); k++) {
if (clusterAssment[k] == j) {
sumX += dataSet.get(k).getX();
sumY += dataSet.get(k).getY();
count++;
}
}
if (count > 0) {
centroids.get(j).setX(sumX / count);
centroids.get(j).setY(sumY / count);
}
}
}
for (int i = 0; i < dataSet.size(); i++) {
System.out.println(dataSet.get(i) + "所属的类别为:" + clusterAssment[i]);
}
}
//读取数据
static List<Point> readDataSet(String fileName) {
List<Point> dataSet = new ArrayList<Point>();
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(fileName));
String line = null;
while ((line = br.readLine()) != null) {
String[] tmp = line.split("\t");
double x = Double.parseDouble(tmp[0]);
double y = Double.parseDouble(tmp[1]);
dataSet.add(new Point(x, y));
}
} catch (IOException e) {
e.printStackTrace();
} finally {
try {
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
return dataSet;
}
//初始化质心
static List<Point> initCentroids(List<Point> dataSet, int k) {
List<Point> centroids = new ArrayList<Point>();
for (int i = 0; i < k; i++) {
Point centroid = dataSet.get((int) (Math.random() * dataSet.size() - 1));
centroids.add(centroid);
}
return centroids;
}
//计算两点之间的距离
static double calcDistance(Point p1, Point p2) {
return Math.sqrt(Math.pow(p1.getX() - p2.getX(), 2) + Math.pow(p1.getY() - p2.getY(), 2));
}
}
```
在这个示例中,我们假设我们的训练数据集trainSet.txt是一个包含两列数据的文本文件,每行包含两个浮点数,分别表示数据点在 X 和 Y 方向上的坐标。在代码中,我们使用了一个名为Point的内部类来表示每个数据点。然后,我们定义了一个readDataSet()方法来读取数据集,initCentroids()方法来初始化质心,calcDistance()方法来计算两点之间的距离。最后,我们使用Kmeans算法对数据集进行聚类,并输出每个数据点所属于的类别。
希望这个例子可以帮到您,若有不理解的地方,可以随时问我。
阅读全文