感知机学习算法Java实现

感知机学习算法Java实现。

Perceptron类实现了感知机,其中:

perceptronOriginal()方法实现感知机学习算法的原始形式;

perceptronAnother()方法实现感知机学习算法的对偶形式(此处仍有bug)。两种形式都假设训练数据线性可分,否则无法收敛。

import java.util.Scanner;


/**
 * Perceptron for binary classification, trained with both the primal and the dual form of the
 * perceptron learning algorithm.
 *
 * <p>{@link #main} reads a data set from standard input (sample count {@code N}, feature count
 * {@code M}, the {@code N x M} feature matrix and {@code N} labels). It then trains with both forms
 * and prints the learned weights {@code w} and bias {@code b} for each.
 *
 * <p>Both training methods sweep over all samples repeatedly and update on every sample with
 * {@code y_i * (w . x_i + b) <= 0}. They stop after a sweep without any update. Data that is not
 * linearly separable never produces such a sweep, so the number of sweeps is capped at
 * {@link #MAX_EPOCHS}.
 */
public class Perceptron {
    /** Capacity of the per-sample arrays; {@code N} must not exceed this. */
    private static final int MAX_N = 1010;
    /** Capacity of the per-feature arrays; {@code M} must not exceed this. */
    private static final int MAX_M = 101;
    /** Upper bound on full passes over the data before a training method gives up. */
    private static final int MAX_EPOCHS = 100000;
    /** Learning rate (step size) used by both training forms; never modified. */
    private static final double LEARNING_RATE = 0.1;

    /** Feature vectors: {@code x[i][k]} is feature k of sample i. */
    private static double[][] x = new double[MAX_N][MAX_M];
    /** Class labels; the input prompt asks for -1 or +1. */
    private static double[] y = new double[MAX_N];
    /** Number of samples in use. */
    private static int N = 0;
    /** Number of features per sample. */
    private static int M = 0;

    /** Weights learned by the primal form; rebuilt from {@link #alpha} after the dual form. */
    private static double[] w = new double[MAX_M];
    /** Bias term, written by whichever training form ran last. */
    private static double b = 0;

    /** Dual coefficients: learning rate times the number of updates made on each sample. */
    private static double[] alpha = new double[MAX_N];
    /** Gram matrix of pairwise inner products, {@code G[i][j] = x_i . x_j}. */
    private static double[][] G = new double[MAX_N][MAX_N];

    /**
     * Trains {@link #w} and {@link #b} with the primal form of the perceptron algorithm.
     *
     * <p>Starts from {@code w = 0, b = 0}. Every misclassified sample moves the model by
     * {@code w += eta * y_i * x_i} and {@code b += eta * y_i}. A sample lying exactly on the
     * hyperplane counts as misclassified.
     */
    private static void perceptronOriginal() {
        for (int k = 0; k < M; k++) {
            w[k] = 0;
        }
        b = 0;
        for (int epoch = 0; epoch < MAX_EPOCHS; epoch++) {
            boolean updated = false;
            for (int i = 0; i < N; i++) {
                double score = 0;
                for (int k = 0; k < M; k++) {
                    score += w[k] * x[i][k];
                }
                score += b;
                if (score * y[i] <= 0) {
                    updated = true;
                    for (int k = 0; k < M; k++) {
                        w[k] += LEARNING_RATE * y[i] * x[i][k];
                    }
                    b += LEARNING_RATE * y[i];
                }
            }
            if (!updated) {
                return; // a full sweep without mistakes: converged
            }
        }
        warnNotConverged("perceptronOriginal");
    }

    /**
     * Fills {@link #G} with the inner products of every pair of the first {@code N} samples.
     *
     * <p>Each entry is the full dot product over all {@code M} features. Every entry in use is
     * overwritten, so values from an earlier call never leak into the result.
     */
    private static void buildGramMatrix() {
        for (int i = 0; i < N; i++) {
            for (int j = i; j < N; j++) {
                double dot = 0;
                for (int k = 0; k < M; k++) {
                    dot += x[i][k] * x[j][k];
                }
                // The Gram matrix is symmetric, so compute each pair once.
                G[i][j] = dot;
                G[j][i] = dot;
            }
        }
    }

    /**
     * Trains the dual coefficients {@link #alpha} and the bias {@link #b} with the dual form of the
     * perceptron algorithm.
     *
     * <p>The model is {@code f(x) = sum_j alpha_j * y_j * (x_j . x) + b}, with the inner products
     * taken from the Gram matrix. A misclassified sample i is updated with {@code alpha_i += eta}
     * and {@code b += eta * y_i}, and every update is printed to standard output. The equivalent
     * primal weights are {@code w = sum_i alpha_i * y_i * x_i}.
     */
    public static void perceptronAnother() {
        buildGramMatrix();
        for (int i = 0; i < N; i++) {
            alpha[i] = 0;
        }
        b = 0;
        for (int epoch = 0; epoch < MAX_EPOCHS; epoch++) {
            boolean updated = false;
            for (int i = 0; i < N; i++) {
                double score = 0;
                for (int j = 0; j < N; j++) {
                    score += alpha[j] * y[j] * G[j][i];
                }
                score += b;
                if (y[i] * score <= 0) {
                    updated = true;
                    alpha[i] += LEARNING_RATE;
                    b += LEARNING_RATE * y[i];
                    System.out.println("alpha[" + i + "]:" + alpha[i] + ",b:" + b);
                }
            }
            if (!updated) {
                return; // a full sweep without mistakes: converged
            }
        }
        warnNotConverged("perceptronAnother");
    }

    /** Reports on standard error that a training loop hit {@link #MAX_EPOCHS} without converging. */
    private static void warnNotConverged(String method) {
        System.err.println(method + ": no convergence after " + MAX_EPOCHS
                + " epochs; the data may not be linearly separable");
    }

    /** Rebuilds the primal weights {@code w = sum_i alpha_i * y_i * x_i} from the dual solution. */
    private static void recoverWeightsFromAlpha() {
        for (int k = 0; k < M; k++) {
            w[k] = 0;
        }
        for (int i = 0; i < N; i++) {
            for (int k = 0; k < M; k++) {
                w[k] += alpha[i] * y[i] * x[i][k];
            }
        }
    }

    /** Prints the weights from the last feature down to the first, followed by the bias. */
    private static void printModel() {
        for (int k = M - 1; k >= 0; k--) {
            System.out.println("w" + k + ":" + w[k]);
        }
        System.out.println("b:" + b);
    }

    /**
     * Reads the next integer from {@code in} and checks that it is a valid count.
     *
     * @throws IllegalArgumentException if the value is negative or exceeds {@code max}
     *     (the capacity of the fixed-size arrays)
     */
    private static int readCount(Scanner in, String name, int max) {
        int value = in.nextInt();
        if (value < 0 || value > max) {
            throw new IllegalArgumentException(name + " must be in [0, " + max + "], got " + value);
        }
        return value;
    }

    /**
     * Reads a data set from standard input, trains with both forms and prints each model.
     *
     * @param args unused
     * @throws IllegalArgumentException if {@code N} or {@code M} is out of range
     */
    public static void main(String[] args) {
        try (Scanner in = new Scanner(System.in)) {
            System.out.print("input N: ");
            N = readCount(in, "N", MAX_N);
            System.out.print("input M: ");
            M = readCount(in, "M", MAX_M);
            System.out.println("input x:");
            for (int i = 0; i < N; i++) {
                for (int k = 0; k < M; k++) {
                    x[i][k] = in.nextDouble();
                }
            }
            System.out.println("input y(y={-1,+1}):");
            for (int i = 0; i < N; i++) {
                y[i] = in.nextDouble();
            }
        }
        System.out.println("perceptron original......");
        perceptronOriginal();
        printModel();
        System.out.println("perceptron another......");
        perceptronAnother();
        recoverWeightsFromAlpha();
        printModel();
    }
}