k-means算法Java一维实现

这里的程序与标准的k-means略有不同：聚类个数K根据数据量自动确定；初始中心点取自把[最小值, 最大值]区间K等分后各段的中点；最多迭代10次；最后只返回非空簇的中心点。

k_means方法返回K-means聚类的若干中心点。

代码:

import java.util.ArrayList;
import java.util.Collections;

/**
 * One-dimensional k-means clustering demo.
 *
 * <p>This is a variant of textbook k-means: the number of clusters is derived
 * from the input size, the initial centers are the midpoints of equal-width
 * slices of the value range, and clusters that end up empty are dropped from
 * the result.
 */
public class Prophet_kmeans {
    /** Starting value for the cluster count before it is adjusted to the input size. */
    private static final int MAXK = 100;

    /** Maximum number of assign/update rounds performed by {@code k_means}. */
    private static final int MAX_ITERATIONS = 10;

    /** Number of non-empty clusters produced by the most recent {@code k_means} call. */
    private static int K = 0;

    /**
     * Clusters the given values and returns the centers of the non-empty clusters.
     *
     * <p>The input list is left untouched; a sorted copy is used internally so
     * that the per-cluster sums are accumulated in ascending value order no
     * matter how the caller ordered the input. As a side effect the static
     * field {@code K} is set to the size of the returned list.
     *
     * @param list the values to cluster; must not be null or contain null
     * @return the centers of all non-empty clusters, in cluster-index order;
     *         an empty list when {@code list} is empty
     * @throws NullPointerException if {@code list} is null or contains null
     */
    private static ArrayList<Double> k_means(ArrayList<Double> list) {
        // Nothing to cluster: report zero clusters instead of failing on get(0).
        if (list.isEmpty()) {
            K = 0;
            return new ArrayList<Double>();
        }

        // Sort a private copy so the caller's list keeps its original order.
        ArrayList<Double> sorted = new ArrayList<Double>(list);
        Collections.sort(sorted);

        int n = sorted.size();
        double[] point = new double[n];
        double minn = sorted.get(0);
        double maxx = minn;
        for (int i = 0; i < n; i++) {
            double value = sorted.get(i);
            if (value < minn) {
                minn = value;
            }
            if (value > maxx) {
                maxx = value;
            }
            point[i] = value;
        }

        int clusterCount = chooseClusterCount(n);
        double[] center = initialCenters(minn, maxx, clusterCount);

        int[] belong = new int[n];                   // zero-filled: every point starts in cluster 0
        double[] total = new double[clusterCount];   // per-cluster sum of member values
        int[] count = new int[clusterCount];         // per-cluster member count
        for (int round = 0; round < MAX_ITERATIONS; round++) {
            boolean changed = assignToNearest(point, center, belong);
            recomputeCenters(point, belong, center, total, count);
            // From the second round on, an unchanged assignment reproduces the
            // same centers, so every further round would be a no-op. Round 0 is
            // excluded because its centers were the initial guesses, not means.
            if (!changed && round > 0) {
                break;
            }
        }

        ArrayList<Double> result = new ArrayList<Double>();
        for (int k = 0; k < clusterCount; k++) {
            if (count[k] != 0) {
                result.add(center[k]);
            }
        }
        K = result.size();
        return result;
    }

    /**
     * Picks the number of clusters to start with for {@code n} points (n >= 1).
     *
     * <p>NOTE(review): the test uses n/4 but the assignment uses n/2, so for
     * 200 <= n < 400 the result is larger than MAXK. Kept as in the original
     * because the intended rule is not clear from this file - confirm.
     */
    private static int chooseClusterCount(int n) {
        int clusters = MAXK;
        if (clusters > n / 4) {
            clusters = n / 2;
        }
        if (clusters == 0) {
            clusters = n;   // only reachable for n == 1
        }
        return clusters;
    }

    /** Returns the midpoints of {@code clusterCount} equal-width slices of [minn, maxx]. */
    private static double[] initialCenters(double minn, double maxx, int clusterCount) {
        double[] center = new double[clusterCount];
        double delta = (maxx - minn) / clusterCount;
        double p = minn + delta / 2;
        for (int i = 0; i < clusterCount; i++) {
            center[i] = p;
            p += delta;
        }
        return center;
    }

    /**
     * Moves every point to its nearest center, updating {@code belong} in place.
     * A point keeps its current cluster unless another center is strictly closer.
     *
     * @return true if at least one point changed cluster
     */
    private static boolean assignToNearest(double[] point, double[] center, int[] belong) {
        boolean changed = false;
        for (int i = 0; i < point.length; i++) {
            int best = belong[i];
            double bestDist = Math.abs(point[i] - center[best]);
            for (int k = 0; k < center.length; k++) {
                double dist = Math.abs(point[i] - center[k]);
                if (dist < bestDist) {
                    bestDist = dist;
                    best = k;
                }
            }
            if (best != belong[i]) {
                belong[i] = best;
                changed = true;
            }
        }
        return changed;
    }

    /**
     * Recomputes each center as the mean of its members. {@code total} and
     * {@code count} are reset and refilled; a cluster without members keeps its
     * previous center.
     */
    private static void recomputeCenters(double[] point, int[] belong, double[] center,
                                         double[] total, int[] count) {
        for (int k = 0; k < center.length; k++) {
            total[k] = 0;
            count[k] = 0;
        }
        for (int i = 0; i < point.length; i++) {
            int k = belong[i];
            total[k] += point[i];
            count[k]++;
        }
        for (int k = 0; k < center.length; k++) {
            if (count[k] != 0) {
                center[k] = total[k] / (double) count[k];
            }
        }
    }

    /** Runs the clustering on a small sample and prints the cluster count and centers. */
    public static void main(String[] args) {
        ArrayList<Double> list = new ArrayList<Double>();
        list.add(1.0);
        list.add(2.0);
        list.add(3.0);
        list.add(2.2);
        list.add(2.1);
        list.add(1.5);
        list.add(9.9);
        list.add(7.5);
        list.add(8.8);
        list.add(6.9);
        list.add(8.7);
        ArrayList<Double> ansList = k_means(list);
        System.out.println("K == " + K);
        for (double number : ansList) {
            System.out.println(number);
        }
    }

}

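输出结果如下（注意：下面的数值对应初始中心点从最小值 minn 开始的版本；按上面代码中 minn + delta / 2 的初始化方式运行，K 仍为 4，后两个中心点不变，但前两个中心点约为 1.76 和 3.0）: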

K == 4

1.25

2.325

7.2

9.133333333333333