Implementing the K-means clustering algorithm in Java, with a txt training file and all code in k3.java
Time: 2024-05-11 21:15:48
Sure. Below is a simple Java implementation of the K-means clustering algorithm. The training set is read from a txt file, and all of the code lives in k3.java:
```
import java.util.*;
import java.io.*;

public class k3 {
    // Number of clusters
    static int k = 3;
    // Name of the input file
    static String inputFile = "train_data.txt";

    public static void main(String[] args) {
        // Read the data from the input file
        double[][] data = readData(inputFile);
        // Initialize the cluster centroids
        double[][] centroids = initCentroids(data);
        // Iterate to compute the clustering result
        int maxIter = 500;
        kmeans(data, centroids, maxIter);
    }
    // Read the input file: one sample per line, values separated by commas
    public static double[][] readData(String fileName) {
        ArrayList<double[]> dataList = new ArrayList<double[]>();
        // try-with-resources closes the reader even if parsing fails
        try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // skip blank lines
                }
                String[] values = line.split(",");
                double[] row = new double[values.length];
                for (int i = 0; i < values.length; i++) {
                    row[i] = Double.parseDouble(values[i].trim());
                }
                dataList.add(row);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        // Fail with a clear message instead of an IndexOutOfBoundsException
        if (dataList.isEmpty()) {
            throw new IllegalStateException("No data could be read from " + fileName);
        }
        return dataList.toArray(new double[0][]);
    }
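    // Initialize the centroids with k distinct, randomly chosen data points
    public static double[][] initCentroids(double[][] data) {
        if (k > data.length) {
            throw new IllegalArgumentException("k must not exceed the number of data points");
        }
        double[][] centroids = new double[k][];
        // Track chosen row indices in a set; Arrays.binarySearch would only
        // be valid on a sorted array
        Set<Integer> chosen = new HashSet<Integer>();
        int i = 0;
        while (i < k) {
            int randomIndex = (int) (Math.random() * data.length);
            // add() returns false if this index was already picked
            if (chosen.add(randomIndex)) {
                // Copy the row so the centroid never aliases the input data
                centroids[i++] = data[randomIndex].clone();
            }
        }
        return centroids;
    }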
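    // Update each centroid to the mean of the points assigned to it
    public static double[][] updateCentroids(double[][] data, int[] labels) {
        double[][] centroids = new double[k][data[0].length];
        int[] count = new int[k];
        for (int i = 0; i < data.length; i++) {
            int label = labels[i];
            for (int j = 0; j < data[i].length; j++) {
                centroids[label][j] += data[i][j];
            }
            count[label] += 1;
        }
        for (int i = 0; i < k; i++) {
            if (count[i] == 0) {
                // Empty cluster: dividing by zero would yield NaN, so re-seed
                // the centroid from a random data point (one common workaround)
                centroids[i] = data[(int) (Math.random() * data.length)].clone();
                continue;
            }
            for (int j = 0; j < centroids[i].length; j++) {
                centroids[i][j] /= count[i];
            }
        }
        return centroids;
    }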
    // Compute the Euclidean distance between two vectors
    public static double dist(double[] v1, double[] v2) {
        double distance = 0;
        for (int i = 0; i < v1.length; i++) {
            double diff = v1[i] - v2[i];
            distance += diff * diff;
        }
        return Math.sqrt(distance);
    }
    // Assign each data point to its nearest centroid
    public static int[] assignLabels(double[][] data, double[][] centroids) {
        int[] labels = new int[data.length];
        for (int i = 0; i < data.length; i++) {
            double minDist = Double.MAX_VALUE;
            for (int j = 0; j < centroids.length; j++) {
                double distance = dist(data[i], centroids[j]);
                if (distance < minDist) {
                    minDist = distance;
                    labels[i] = j;
                }
            }
        }
        return labels;
    }
    // K-means clustering: alternate assignment and update steps
    public static void kmeans(double[][] data, double[][] centroids, int maxIter) {
        int[] labels = null;
        for (int iteration = 1; iteration <= maxIter; iteration++) {
            int[] newLabels = assignLabels(data, centroids);
            // Stop early once the assignments no longer change
            if (Arrays.equals(newLabels, labels)) {
                break;
            }
            labels = newLabels;
            centroids = updateCentroids(data, labels);
        }
        System.out.println("Cluster assignments: " + Arrays.toString(labels));
        System.out.println("Final centroids: " + Arrays.deepToString(centroids));
    }
}
```
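As `readData` is written, each line of `train_data.txt` is expected to hold one sample as comma-separated numbers, with the same number of values on every line. The following is a minimal sketch of that parsing step on in-memory lines. The class name `InputFormatSketch`, the `parseLine` helper, and the sample values are made up for illustration and are not part of k3.java.

```java
import java.util.Arrays;

// Hypothetical helper for illustration; not part of k3.java
class InputFormatSketch {

    // Parse one comma-separated line into a feature vector
    static double[] parseLine(String line) {
        String[] values = line.split(",");
        double[] row = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            row[i] = Double.parseDouble(values[i].trim());
        }
        return row;
    }

    public static void main(String[] args) {
        // Made-up sample lines in the assumed format: two features per sample
        String[] lines = {"1.0,2.0", "1.5,1.8", "8.0,8.0"};
        for (String line : lines) {
            System.out.println(Arrays.toString(parseLine(line)));
        }
    }
}
```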
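A few points to note:
1. The code initializes the centroids randomly with Math.random(), which cannot be seeded, so results can differ between runs. For reproducible results, draw the indices from a java.util.Random created with a fixed seed.
2. Distances are Euclidean. Depending on the data type or the range of the features, another distance metric may be more appropriate.
3. K is fixed in the code. If the right K is unknown, methods such as the Elbow Method can help choose it.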
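The first and third notes can be sketched together. A `java.util.Random` with a fixed seed makes the initialization reproducible, and the within-cluster sum of squares (WCSS) is the quantity the Elbow Method compares across values of K. This is only a rough sketch under assumptions: the class name `ElbowSketch`, its helper methods, the seed, and the toy data are all invented for illustration and are not part of k3.java.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper class for illustration; not part of k3.java
class ElbowSketch {

    // Squared Euclidean distance between two vectors of equal length
    static double sqDist(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    // Index of the centroid closest to point p
    static int nearest(double[] p, double[][] centroids) {
        int best = 0;
        for (int c = 1; c < centroids.length; c++) {
            if (sqDist(p, centroids[c]) < sqDist(p, centroids[best])) {
                best = c;
            }
        }
        return best;
    }

    // Pick k distinct rows as initial centroids. The same seed always yields
    // the same choice, which makes runs reproducible. Assumes k <= data.length.
    static double[][] initCentroids(double[][] data, int k, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < data.length; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed));
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = data[indices.get(c)].clone();
        }
        return centroids;
    }

    // Run a fixed number of k-means iterations and return the WCSS:
    // the sum of squared distances from each point to its nearest centroid.
    static double wcss(double[][] data, int k, long seed, int maxIter) {
        double[][] centroids = initCentroids(data, k, seed);
        int dim = data[0].length;
        for (int iter = 0; iter < maxIter; iter++) {
            double[][] sums = new double[k][dim];
            int[] count = new int[k];
            for (double[] p : data) {
                int c = nearest(p, centroids);
                count[c]++;
                for (int j = 0; j < dim; j++) {
                    sums[c][j] += p[j];
                }
            }
            for (int c = 0; c < k; c++) {
                if (count[c] == 0) {
                    continue; // empty cluster: keep its previous centroid
                }
                for (int j = 0; j < dim; j++) {
                    centroids[c][j] = sums[c][j] / count[c];
                }
            }
        }
        double total = 0;
        for (double[] p : data) {
            total += sqDist(p, centroids[nearest(p, centroids)]);
        }
        return total;
    }

    public static void main(String[] args) {
        // Made-up toy data: three loose groups in 2-D
        double[][] data = {
            {1.0, 1.0}, {1.2, 0.8}, {0.9, 1.1},
            {5.0, 5.0}, {5.1, 4.9}, {4.8, 5.2},
            {9.0, 1.0}, {9.2, 0.9}, {8.9, 1.1}
        };
        // Print WCSS for several K; the "elbow" is where the drop flattens out
        for (int k = 1; k <= 5; k++) {
            System.out.printf("k=%d  WCSS=%.3f%n", k, wcss(data, k, 42L, 100));
        }
    }
}
```

Running `main` prints one WCSS value per K. With well-separated groups, the value typically drops sharply until K reaches the number of groups and then levels off, and that bend is the K the Elbow Method would suggest.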