稀疏矩阵乘法、加法等运算的 Java 实现

原创声明:本文系作者原创,转载请写明出处。

一、前言

前几天由于科研需要，一直在研究稀疏表示下的矩阵乘法。虽然最近已经把程序写出来了，但仍然无法处理大规模矩阵（即使已经采用了稀疏存储）。原因可能是乘积结果不够稀疏，或者参与相乘的矩阵本身就不够稀疏。

这里还是把实现的程序贴出来，以供日后研究参考。

二、程序实现功能

首先将稀疏矩阵封装为三元组形式。

程序的主要功能有:

稀疏矩阵的转置

稀疏矩阵的乘法

稀疏矩阵的加法

以及从文本文件导入矩阵等相应功能。

三、代码展示

以下 Java 程序是在 Eclipse 中编写的（需要引入 Weka 和 Jama 两个第三方库）。

package others;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;

import weka.clusterers.SimpleKMeans;
import weka.core.DistanceFunction;
import weka.core.Instances;
import weka.core.converters.ArffLoader;
import Jama.Matrix;
/*
 * 本类可实现稀疏矩阵三元组表示下的矩阵乘法和矩阵加法,以及矩阵转置等。结果也是三元组存储。
 * 但是当数据量非常庞大时,乘积的结果无法存储,会出现内存溢出的现象。
 */
public class SMatrix {

        public Map<ArrayList<Integer>,Integer> Triples;//矩阵的三元组表示
        public int rowNum;//矩阵行数
        public int colNum;//矩阵列数
        
        
        public int getRowNum() {
                return rowNum;
        }


        public void setRowNum(int rowNum) {
                this.rowNum = rowNum;
        }


        public int getColNum() {
                return colNum;
        }


        public void setColNum(int colNum) {
                this.colNum = colNum;
        }


        /*
         * 构造函数1
         */
        public SMatrix(){
                
        }
        
        
        /*
         * 构造函数2
         */
        public SMatrix(Map<ArrayList<Integer>, Integer> triples, int rowNum, int colNum) {

                Triples = triples;
                this.rowNum = rowNum;
                this.colNum = colNum;
        }
        
        /*
         * 构造函数3
         */
        public SMatrix(Map<ArrayList<Integer>, Integer> triples) {

                Triples = triples;
        }
        
        /*
         * 稀疏矩阵相乘函数
         */
        public SMatrix Multiply(SMatrix M,SMatrix N){
                if(M.colNum != N.rowNum){
                        System.out.println("矩阵相乘不满足条件");
                        return null;
                }               
                
                Map<ArrayList<Integer>,Integer> triples = new HashMap<ArrayList<Integer>,Integer>();
                Iterator<Map.Entry<ArrayList<Integer>, Integer>> it1 = M.Triples.entrySet().iterator();
                
                
                int iter = 0;
                while(it1.hasNext()){
                        iter++;
//                      System.out.println("迭代次数:"+iter);
                        Entry<ArrayList<Integer>, Integer> entry = it1.next();
                        ArrayList<Integer> position = entry.getKey();
//                      System.out.println("检查程序:" + position);
                        int value = entry.getValue();   
                        int flag = 0;
                        Iterator<Map.Entry<ArrayList<Integer>, Integer>> it2 = N.Triples.entrySet().iterator();
                        while(it2.hasNext()){
                                Entry<ArrayList<Integer>,Integer> entry2 = it2.next();
                                ArrayList<Integer> position2 = entry2.getKey();
                                int value2 = entry2.getValue();
                                
                                
                                if(position.get(1) == position2.get(0)){
                                        flag = 1;
                                        ArrayList<Integer> temp = new ArrayList<Integer>();
                                        temp.add(position.get(0));
                                        temp.add(position2.get(1));
                                        int v = value * value2;
                                        if(triples.containsKey(temp)){
                                                triples.put(temp, triples.get(temp) + v);
                                                System.out.println(temp+ "\t"+(triples.get(temp) + v));
                                                
                                        }
                                        else{
                                                triples.put(temp, v);
                                                System.out.println(temp + "\t" + v);
                                        }                               
                                }
                                
                        }       
                }       
                SMatrix s = new SMatrix(triples,M.rowNum,N.colNum);
                return s;
        }
        
        
        
        /*
         * 稀疏矩阵相加函数
         */
        public static SMatrix Add(SMatrix M,SMatrix N){
                if(M.colNum != N.colNum || M.rowNum != N.rowNum){
                        System.out.println("矩阵相加不满足条件");
                        return null;
                }
                SMatrix s = new SMatrix();
                Map<ArrayList<Integer>,Integer> triples = new HashMap<ArrayList<Integer>,Integer>();
                Iterator<Map.Entry<ArrayList<Integer>, Integer>> it1 = M.Triples.entrySet().iterator();
                Iterator<Map.Entry<ArrayList<Integer>, Integer>> it2 = N.Triples.entrySet().iterator();
                
                while(it1.hasNext()){
                        Entry<ArrayList<Integer>, Integer> entry = it1.next();
                        ArrayList<Integer> position = entry.getKey();
                        int value = entry.getValue();
                        if(triples.containsKey(position)){
                                triples.put(position, triples.get(position) + value);
                        }else{
                                triples.put(position, value);
                        }
                        
                        
                }
                
                while(it2.hasNext()){
                        Entry<ArrayList<Integer>,Integer> entry = it2.next();
                        ArrayList<Integer> position = entry.getKey();
                        int value = entry.getValue();
                        if(triples.containsKey(position)){
                                triples.put(position, triples.get(position) + value);
                        }else{
                                triples.put(position, value);
                        }
                        
                }
                return s;
        }
        
        
        /*
         * 稀疏矩阵求转置矩阵函数
         */
        public SMatrix Transposition(){
                
                Map<ArrayList<Integer>,Integer> triples = new HashMap<ArrayList<Integer>,Integer>();
                Iterator<Map.Entry<ArrayList<Integer>, Integer>> it = this.Triples.entrySet().iterator();
                while(it.hasNext()){
                        Entry<ArrayList<Integer>, Integer> entry = it.next();
                        ArrayList<Integer> position = entry.getKey();
                        int value = entry.getValue();
                        ArrayList<Integer> transP = new ArrayList<Integer>();
                        transP.add(position.get(1));
                        transP.add(position.get(0));
                        
                        triples.put(transP, value);
                        
                }
                SMatrix s = new SMatrix(triples,this.colNum,this.rowNum);
                return s;
        }


        
        /*
         * 加载文本数据为稀疏矩阵三元组形式的函数
         */
        public SMatrix Load(String file, String delimeter){
                
                
                Map<ArrayList<Integer>,Integer> triples = new HashMap<ArrayList<Integer>,Integer>();
                
                try{
                        File f = new File(file);
                        FileReader fr = new FileReader(f);
                        BufferedReader br = new BufferedReader(fr);
                        
                        String line;
                        
                        while((line = br.readLine()) != null){
                                String[] str = line.trim().split(delimeter);
                                
                                ArrayList<Integer> s = new ArrayList<Integer>();
                                for(int i = 0;i < str.length - 1; i++){
                                        s.add(Integer.parseInt(str[i]));
                                }
                                
                                triples.put(s, Integer.parseInt(str[str.length - 1]));
                                
                        }
                        
                        
                        br.close();
                        fr.close();
                        
                }catch(IOException e){
                        e.printStackTrace();
                }
                SMatrix sm = new SMatrix(triples);
                return sm;
        }
        /*
         * 打印稀疏矩阵(三元组形式)
         */
        public void Print(){
                Map<ArrayList<Integer>, Integer> triples = this.Triples;
                Iterator<Map.Entry<ArrayList<Integer>, Integer>> it = triples.entrySet().iterator();
                int num = 0;
                while(it.hasNext()){
                        Entry<ArrayList<Integer>, Integer> entry = it.next();
                        ArrayList<Integer> position = entry.getKey();
                        num++;
                        System.out.print(num+":");
                        for(Integer in:position){
                                System.out.print(in + "\t");
                        }
                        
                        System.out.println(entry.getValue());
                }
                
        }
        
        
        
public static void main(String[] args){
                
        /*
         * 测试程序
         
        String testS = "data/me";
        int k = 3;
        SMatrix te = new SMatrix();
        te = te.Load(testS,"\t");
        te.rowNum = 4;
        te.colNum = 6;
        System.out.println("打印原矩阵");
        te.Print();
        System.out.println("打印原矩阵的转置矩阵");
        te.Transposition().Print();
        
        System.out.println("打印乘积矩阵");
        SMatrix A2 = new SMatrix();

        A2 = te.Multiply(te, te.Transposition());
        A2.Print();
        */

        
        
        
        
        
        long start = System.currentTimeMillis();
        
                String file1 = "data/AT.txt";//author to term 的稀疏矩阵
                String file2 = "data/CA.txt";//conference to author 的稀疏矩阵
                String delimeter = "    ";
                int k = 11;
                SMatrix M = new SMatrix();
                SMatrix MT = new SMatrix();
        
                SMatrix N = new SMatrix();
                SMatrix NT = new SMatrix();
                SMatrix P = new SMatrix();
                SMatrix Q = new SMatrix();
                
                M = M.Load(file1, delimeter);
                M.colNum = 9225;
                M.rowNum = 6456;
                System.out.println("打印矩阵M");
                M.Print();
                MT = M.Transposition();
                System.out.println("打印矩阵MT");
                MT.Print();
                
                System.out.println("计算M和MT的乘积");
                System.out.println(M.rowNum);
                P = M.Multiply(M, MT);
                System.out.println("打印矩阵M与矩阵M转置的乘积");
                P.Print();
                
                
                
                
                N = N.Load(file2, delimeter);
                N.colNum = 6456;
                N.rowNum = 20;
                System.out.println("打印矩阵N");
                N.Print();
                NT = N.Transposition();
                
                System.out.println("打印矩阵NT:");
                NT.Print();
                
                System.out.println("计算NT 和  N的乘积");
                System.out.println(NT.colNum);
                System.out.println(N.rowNum);
                Q = M.Multiply(NT, N);          
                Q.Print();
                
                
                
                
                
                
                SMatrix A = new SMatrix();
                A = A.Load("data/AA.txt","      ");
                
                SMatrix A1 = new SMatrix();
                SMatrix A2 = new SMatrix();
                System.out.println("计算矩阵A1=P+Q:");
                A1 = SMatrix.Add(Q, P);
                
                System.out.println("打印矩阵A1:");
                A1.Print();
                A2 = SMatrix.Add(A1, A);//得到了比较全面的author to author 矩阵三元组
                
                A2.Print();
                
                
                double[][] matrix = new double[A2.rowNum][A2.colNum];
                
                for(int i = 0;i < A2.rowNum;i++){
                        for (int j = 0; j < A2.colNum; j++) {
                                
                                ArrayList<Integer> list = new ArrayList<Integer>();
                                list.add(i);
                                list.add(j);

                                if (A2.Triples.containsKey(list)) {
                                        matrix[i][j] = A2.Triples.get(list);
                                }
                                else{
                                        matrix[i][j] = 0;       
                                }
                                
                        }
                }


                for(int i = 0;i<A2.rowNum;i++){
                        for(int j = 0;j < A2.colNum;j++){
                                System.out.print(matrix[i][j]+"\t");
                        }
                        System.out.println();
                }
                Matrix Author = new Matrix(matrix);

                //第二步:求矩阵的特征值eigValue及其相应的特征向量矩阵,取前K个(最大的)
                Matrix diagA = Author.eig().getD();

                diagA.print(4, 2);
                int m = diagA.getRowDimension();
                int n = diagA.getColumnDimension();
                
                Matrix eigVector = Author.eig().getV();
                
                eigVector.print(eigVector.getRowDimension(),4);


                //将特征向量输出到文本中。
                String outFile = "data/eigenVector.txt";
                try{
                        File f = new File(outFile);
                        FileOutputStream fout = new FileOutputStream(f);
                        
                        fout.write("@RELATION\teigenVector\n".getBytes());
                        for(int i = n-k;i<n;i++){
                                fout.write(("@ATTRIBUTE\t"+i + "\tREAL\n").getBytes());
                        }
                        fout.write("@DATA\n".getBytes());
                        if(k <= n){
                                for(int i = 0;i < m;i++){
                                        for(int j = n-k;j<n;j++){
                                                Double temp = new Double(eigVector.getArray()[i][j]);
                                                String tem = temp.toString();
                                                fout.write((tem + "\t").getBytes());
                                                
                                        }
                                        fout.write(("\n").getBytes());
                                }
                        }
                }
                catch(IOException e){
                        e.printStackTrace();
                }
                //第三步:对特征向量矩阵进行kmeans聚类
                Instances ins = null;
                
                SimpleKMeans KM = null;
                
                // 目前没有使用到,但是在3.7.10的版本之中可以指定距离算法
                // 默认是欧几里得距离
                DistanceFunction disFun = null;
                
                try {
                        // 读入样本数据
                        File file = new File("data/eigenVector.txt");
                        ArffLoader loader = new ArffLoader();
                        loader.setFile(file);
                        ins = loader.getDataSet();
                        
                        // 初始化聚类器 (加载算法)
                        KM = new SimpleKMeans();
                        KM.setNumClusters(2);           //设置聚类要得到的类别数量
                        
                        KM.setMaxIterations(100);
                        KM.buildClusterer(ins);         //开始进行聚类
                        System.out.println(KM.preserveInstancesOrderTipText());
                        // 打印聚类结果
                        System.out.println(KM.toString());
                        
                
//                      for(String option : KM.getOptions()) {
//                              System.out.println(option);
//                      }
//                      System.out.println("CentroIds:" + tempIns);
                } catch(Exception e) {
                        e.printStackTrace();
                }
                
                
                System.out.println("程序正常结束");
                
                
                long end = System.currentTimeMillis();
                System.out.println(end - start);
                
                
        }
        
}