基于 Java 的 TensorFlow 模型推理实现:基于 BERT 的命名实体识别与基于 Transformer 的文本分类

最近在做文本分类任务。由于实际工程中需要以服务的形式对外提供该功能,因此采用 Java 加载 pb 模型完成推理,现将实现过程记录如下:

1. 基于 Transformer 的文本分类

package com.techwolf.transformer;

import com.alibaba.fastjson.*;
import com.alibaba.fastjson.parser.Feature;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
//import com.alibaba.fastjson.JSONPObject;

//import org.json.JSONObject;

import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;

/**
 * Command-line demo that classifies job samples with a frozen TensorFlow graph.
 *
 * <p>The flow driven by {@link #main(String[])} is: load the lookup tables from the JSON
 * config, turn every sample of the input JSON file into a fixed-width row of feature ids and
 * scores, feed the whole batch to the graph, and print the best-scoring labels per sample.
 *
 * <p>All state is kept in static fields that {@code getConfig()} replaces; nothing here
 * synchronises access to them.
 */
public class JobPredict {
    /** JSON file holding every lookup table read by {@code getConfig()}. */
    private static final String JSON_PATH = "src/main/resources/resource.json";
    /** Serialized graph definition imported by {@code modelInfer}. */
    private static final String MODEL_PATH = "src/main/resources/model.pb";

    /** Number of slots reserved for each of the three tag fields (skills, title, desc). */
    private static final int FIELD_LENGTH = 3;
    /** Width of one sample row: three tag fields of {@link #FIELD_LENGTH} slots each. */
    private static final int PAD_LENGTH = 3 * FIELD_LENGTH;
    /** Maximum number of labels reported per sample. */
    private static final int RETURN_NUM = 5;
    /** Key looked up in {@code feature2id} to obtain the id stored in unused/unknown slots. */
    private static final String PAD_FEATURE = "-1";
    /** Score stored in unused/unknown slots. */
    private static final float PAD_SCORE = 0f;

    private static Map<String, Object> positionToFeature = new HashMap<String, Object>();
    private static Map<String, Object> jobMapping = new HashMap<String, Object>();
    private static Map<String, Object> mergeMapping = new HashMap<String, Object>();
    private static Map<String, Object> featureToId = new HashMap<String, Object>();
    private static Map<String, Object> idToCode = new HashMap<String, Object>();
    private static Map<String, Object> codeToLabel = new HashMap<String, Object>();

    /** Feature ids and their scores, kept as two parallel lists of equal length. */
    private static final class Features {
        private final List<Integer> codes = new ArrayList<Integer>();
        private final List<Float> scores = new ArrayList<Float>();

        void add(int code, float score) {
            codes.add(code);
            scores.add(score);
        }

        void addAll(Features other) {
            codes.addAll(other.codes);
            scores.addAll(other.scores);
        }
    }

    /**
     * Reads a whole file as UTF-8 text.
     *
     * @param fileName path of the file to read
     * @return the file content, or {@code null} when the file cannot be read (the stack trace
     *         is printed in that case)
     * @throws FileNotFoundException kept in the signature for existing callers; a missing file
     *         is reported through the {@code null} return instead
     */
    public static String readJsonFile(String fileName) throws FileNotFoundException {
        // try-with-resources closes the stream on every path, including read failures.
        try (Reader reader = new InputStreamReader(
                new FileInputStream(fileName), java.nio.charset.StandardCharsets.UTF_8)) {
            StringBuilder content = new StringBuilder();
            char[] buffer = new char[8192];
            int read;
            while ((read = reader.read(buffer)) != -1) {
                content.append(buffer, 0, read);
            }
            return content.toString();
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Copies the entries of a JSON object into a plain {@link HashMap}. */
    private static Map<String, Object> jsonTOMap(JSONObject jsobj) {
        return new HashMap<String, Object>(jsobj);
    }

    /** Returns the named child object of {@code root} as a map, failing loudly if it is absent. */
    private static Map<String, Object> section(JSONObject root, String key) {
        JSONObject child = root.getJSONObject(key);
        if (child == null) {
            throw new IllegalStateException("Missing section '" + key + "' in " + JSON_PATH);
        }
        return jsonTOMap(child);
    }

    /**
     * Loads every lookup table from {@link #JSON_PATH} into the static maps.
     *
     * @throws FileNotFoundException if the config file cannot be read
     * @throws IllegalStateException if the file is empty or lacks one of the expected sections
     */
    private static void getConfig() throws FileNotFoundException {
        String jsonStr = readJsonFile(JSON_PATH);
        if (jsonStr == null) {
            throw new FileNotFoundException("Cannot read config file: " + JSON_PATH);
        }
        JSONObject obj = JSON.parseObject(jsonStr);
        if (obj == null) {
            throw new IllegalStateException("Config file is empty: " + JSON_PATH);
        }

        positionToFeature = section(obj, "position2feature");
        featureToId = section(obj, "feature2id");
        jobMapping = section(obj, "job_mapping");
        mergeMapping = section(obj, "merge_mapping");
        idToCode = section(obj, "id2position");
        codeToLabel = section(obj, "position_mapping");
        System.out.println("config data loaded!");
    }

    /**
     * Replaces every backslash-u escape followed by four hexadecimal digits with the UTF-16
     * code unit it denotes; all other text is copied unchanged.
     *
     * <p>An escape marker that is not followed by four hexadecimal digits (for example one cut
     * off at the end of the input) is copied literally.
     *
     * @param utfString text that may contain escapes; must not be {@code null}
     * @return the decoded text
     */
    public static String convert(String utfString) {
        int length = utfString.length();
        StringBuilder decoded = new StringBuilder(length);
        int pos = 0;
        while (pos < length) {
            int marker = utfString.indexOf("\\u", pos);
            if (marker == -1) {
                break;
            }
            decoded.append(utfString, pos, marker);
            int codeUnit = parseHex4(utfString, marker + 2);
            if (codeUnit >= 0) {
                decoded.append((char) codeUnit);
                pos = marker + 6;
            } else {
                // Not a complete escape: keep the marker and always advance so the scan ends.
                decoded.append("\\u");
                pos = marker + 2;
            }
        }
        decoded.append(utfString, pos, length);
        return decoded.toString();
    }

    /** Parses four hex digits starting at {@code start}; returns -1 if they are not all there. */
    private static int parseHex4(String s, int start) {
        if (start + 4 > s.length()) {
            return -1;
        }
        int value = 0;
        for (int k = start; k < start + 4; k++) {
            int digit = Character.digit(s.charAt(k), 16);
            if (digit < 0) {
                return -1;
            }
            value = (value << 4) | digit;
        }
        return value;
    }

    /** Returns the id that {@code feature2id} assigns to the padding feature. */
    private static int paddingFeatureId() {
        Object id = featureToId.get(PAD_FEATURE);
        if (!(id instanceof Number)) {
            throw new IllegalStateException(
                    "feature2id has no numeric entry for padding key '" + PAD_FEATURE + "'");
        }
        return ((Number) id).intValue();
    }

    /**
     * Converts one tag field into exactly {@link #FIELD_LENGTH} (feature id, score) slots.
     *
     * <p>Each item's {@code code} is first rewritten through {@code merge_mapping} when present,
     * then resolved through {@code job_mapping}, {@code position2feature} and
     * {@code feature2id}. An item whose chain cannot be resolved becomes a padding slot with
     * score {@link #PAD_SCORE}. Missing slots are padded; entries beyond
     * {@link #FIELD_LENGTH} are ignored because the row handed to the model has a fixed width.
     * NOTE(review): only the first entries are kept — confirm the input is ordered by relevance.
     *
     * @param jsonArray items of one tag field, each with {@code code} and {@code score};
     *                  {@code null} is treated as an empty field
     * @return the slots of this field
     * @throws IllegalArgumentException if an item has no {@code score}
     */
    private static Features getCodeAndScore(JSONArray jsonArray) {
        Features features = new Features();
        int padId = paddingFeatureId();
        int itemCount = jsonArray == null ? 0 : Math.min(jsonArray.size(), FIELD_LENGTH);

        for (int i = 0; i < itemCount; i++) {
            JSONObject item = jsonArray.getJSONObject(i);
            String code = String.valueOf(item.get("code"));
            // getFloat accepts both numeric and string-encoded scores.
            Float score = item.getFloat("score");
            if (score == null) {
                throw new IllegalArgumentException("Item with code " + code + " has no score");
            }

            Object merged = mergeMapping.get(code);
            if (merged != null) {
                code = merged.toString();
                System.out.println("replace id:" + code);
            }

            Object position = jobMapping.get(code);
            Object featSeq = position == null ? null : positionToFeature.get(position.toString());
            Object featId = featSeq == null ? null : featureToId.get(featSeq.toString());
            if (featId instanceof Number) {
                features.add(((Number) featId).intValue(), score);
            } else {
                features.add(padId, PAD_SCORE);
            }
        }
        for (int i = itemCount; i < FIELD_LENGTH; i++) {
            features.add(padId, PAD_SCORE);
        }
        return features;
    }

    /**
     * Reads a file fully, terminating the JVM with status 1 if that fails.
     *
     * @param path file to read
     * @return the file content
     */
    private static byte[] readAllByteOrExit(Path path){
        try{
            return Files.readAllBytes(path);
        }catch (IOException e){
            System.out.println("Failed to read[" + path + "]:" + e.getMessage());
            System.exit(1);
        }
        return null;
    }

    /**
     * Builds one feature row per entry of the sample file.
     *
     * <p>For every top-level key the {@code skills}, {@code title} and {@code desc} arrays under
     * {@code _source.tags} are converted with {@code getCodeAndScore} and concatenated in that
     * order. Samples keep the key order of the file.
     *
     * @param testFile path of the JSON sample file
     * @return one row of {@link #PAD_LENGTH} slots per sample
     * @throws FileNotFoundException if the file cannot be read
     * @throws IllegalArgumentException if a sample has no {@code _source.tags} object
     */
    private static List<Features> getDataContent(String testFile) throws FileNotFoundException {
        String jsonStr = readJsonFile(testFile);
        if (jsonStr == null) {
            throw new FileNotFoundException("Cannot read sample file: " + testFile);
        }
        // OrderedField keeps the samples in file order; one parse is enough for that.
        JSONObject samplesJson = JSON.parseObject(jsonStr, Feature.OrderedField);
        List<Features> samples = new ArrayList<Features>();
        if (samplesJson != null) {
            for (String userId : samplesJson.keySet()) {
                JSONObject entry = samplesJson.getJSONObject(userId);
                JSONObject source = entry == null ? null : entry.getJSONObject("_source");
                JSONObject tags = source == null ? null : source.getJSONObject("tags");
                if (tags == null) {
                    throw new IllegalArgumentException(
                            "Sample " + userId + " has no _source.tags object");
                }
                Features sample = new Features();
                sample.addAll(getCodeAndScore(tags.getJSONArray("skills")));
                sample.addAll(getCodeAndScore(tags.getJSONArray("title")));
                sample.addAll(getCodeAndScore(tags.getJSONArray("desc")));
                samples.add(sample);
            }
        }
        System.out.println("ok! sample feature created.");
        return samples;
    }

    /**
     * Sorts {@code arr} in place and reports where each sorted element came from.
     *
     * <p>The sort is a stable bubble sort: elements that compare equal keep their relative
     * order.
     *
     * @param arr  values to sort; reordered by this call
     * @param desc {@code true} for descending order, {@code false} for ascending
     * @return an array whose element {@code k} is the original index of the value that ends up
     *         at position {@code k}
     */
    public static int[] arraySort(float[] arr, boolean desc) {
        int n = arr.length;
        int[] order = new int[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }

        for (int pass = 0; pass < n - 1; pass++) {
            boolean swapped = false;
            for (int j = 0; j < n - pass - 1; j++) {
                boolean outOfOrder = desc ? arr[j] < arr[j + 1] : arr[j] > arr[j + 1];
                if (outOfOrder) {
                    float value = arr[j];
                    arr[j] = arr[j + 1];
                    arr[j + 1] = value;

                    int origin = order[j];
                    order[j] = order[j + 1];
                    order[j + 1] = origin;
                    swapped = true;
                }
            }
            if (!swapped) {
                // A pass without swaps means the remaining passes would not swap either.
                break;
            }
        }
        return order;
    }

    /**
     * Fills the model input arrays from the prepared samples.
     *
     * @param indexes per-sample position matrix; row {@code i} is replaced by a
     *                {@link #PAD_LENGTH} x {@link #PAD_LENGTH} identity matrix
     * @param codes   receives the feature ids of sample {@code i} in row {@code i}
     * @param scores  receives the feature scores of sample {@code i} in row {@code i}
     * @param data    samples produced by {@code getDataContent}
     */
    private static void featToTensor(float[][][] indexes, int[][] codes, float[][] scores,
                                     List<Features> data) {
        for (int i = 0; i < data.size(); i++) {
            Features sample = data.get(i);

            float[][] identity = new float[PAD_LENGTH][PAD_LENGTH];
            for (int step = 0; step < PAD_LENGTH; step++) {
                identity[step][step] = 1;
            }
            indexes[i] = identity;

            for (int j = 0; j < PAD_LENGTH; j++) {
                codes[i][j] = sample.codes.get(j);
                scores[i][j] = sample.scores.get(j);
            }
        }
    }

    /**
     * Runs the frozen graph on the whole batch and collects the best labels per sample.
     *
     * <p>The graph, the session and every tensor are released when the call returns, also when
     * it fails. The number of classes is taken from the shape of the fetched tensor.
     *
     * @param data samples produced by {@code getDataContent}
     * @return for each sample a map from label to score holding at most {@link #RETURN_NUM}
     *         entries in descending score order; classes that resolve to the same label share
     *         one entry
     */
    private static List<HashMap<String, Float>> modelInfer(List<Features> data) {
        List<HashMap<String, Float>> predictResult = new ArrayList<HashMap<String, Float>>();
        int batchSize = data.size();
        if (batchSize == 0) {
            return predictResult;
        }

        float[][][] indexes = new float[batchSize][PAD_LENGTH][PAD_LENGTH];
        int[][] codes = new int[batchSize][PAD_LENGTH];
        float[][] scores = new float[batchSize][PAD_LENGTH];
        float transKeepProb = 1.0f;
        float multiKeepProb = 1.0f;
        featToTensor(indexes, codes, scores, data);

        byte[] graphDef = readAllByteOrExit(Paths.get(MODEL_PATH));
        float[][] result;
        try (Graph g = new Graph()) {
            g.importGraphDef(graphDef);
            try (Session sess = new Session(g);
                 Tensor tensorIndex = Tensor.create(indexes);
                 Tensor tensorCode = Tensor.create(codes);
                 Tensor tensorScore = Tensor.create(scores);
                 Tensor tensorTransProb = Tensor.create(transKeepProb);
                 Tensor tensorMultiProb = Tensor.create(multiKeepProb);
                 Tensor tensorClassResult = sess.runner()
                         .feed("input_x:0", tensorCode)
                         .feed("input_x_score:0", tensorScore)
                         .feed("embed_position:0", tensorIndex)
                         .feed("trans_keep_prob:0", tensorTransProb)
                         .feed("multi_keep_prob:0", tensorMultiProb)
                         .fetch("discriminator/softmax_score:0")
                         .run().get(0)) {
                long[] shape = tensorClassResult.shape();
                if (shape.length != 2) {
                    throw new IllegalStateException(
                            "Expected a rank-2 score tensor but got shape " + Arrays.toString(shape));
                }
                result = (float[][]) tensorClassResult.copyTo(
                        new float[(int) shape[0]][(int) shape[1]]);
            }
        }

        for (float[] resultVec : result) {
            // arraySort orders resultVec in place, so resultVec[s] is the s-th best score and
            // resultIndex[s] the class it belongs to.
            int[] resultIndex = arraySort(resultVec, true);
            HashMap<String, Float> predictSample = new LinkedHashMap<String, Float>();
            int top = Math.min(RETURN_NUM, resultVec.length);
            for (int s = 0; s < top; s++) {
                String sampleCode = Integer.toString(resultIndex[s]);
                Object label = codeToLabel.get(String.valueOf(idToCode.get(sampleCode)));
                predictSample.put(label == null ? null : label.toString(), resultVec[s]);
            }
            predictResult.add(predictSample);
        }
        return predictResult;
    }

    /**
     * Loads the configuration, scores the bundled sample file and prints the predictions.
     *
     * @param args ignored
     * @throws IOException if the configuration or the sample file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String testFile = "src/main/data/predict_data.json";

        getConfig();
        List<Features> samples = getDataContent(testFile);
        List<HashMap<String, Float>> result = modelInfer(samples);

        System.out.println(result);
    }

}

2. 基于 BERT 的命名实体识别(NER)

package com.techwolf.bert;

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;


/**
 * Command-line demo that tags two sample queries with a frozen TensorFlow sequence-labelling
 * graph.
 *
 * <p>Text is split into single characters, each character is mapped to its line number in the
 * vocabulary file, and the resulting id/mask rows are fed to the graph. The decoded tag ids and
 * the score of each query are printed.
 */
public class BertNerPredict {
    private static final String VOCAB_PATH = "src/main/resources/vocab.txt";

    /** Token to id, where the id is the zero-based line number of the token in the vocab file. */
    private static final Map<String, Integer> word2id = loadVocab(VOCAB_PATH);

    /**
     * Reads the vocabulary file, one token per line, decoding it as UTF-8.
     *
     * @param path vocabulary file
     * @return map from the trimmed line to its zero-based line number
     * @throws IllegalStateException if the file cannot be read; an empty vocabulary would only
     *         surface later as an unrelated {@code NullPointerException}
     */
    private static Map<String, Integer> loadVocab(String path) {
        Map<String, Integer> vocab = new HashMap<String, Integer>();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(
                new FileInputStream(path), java.nio.charset.StandardCharsets.UTF_8))) {
            int id = 0;
            String line;
            // readLine() returns null at end of file, so test it before trimming.
            while ((line = reader.readLine()) != null) {
                vocab.put(line.trim(), id++);
            }
        } catch (IOException e) {
            throw new IllegalStateException("Failed to load vocabulary from " + path, e);
        }
        return vocab;
    }

    /**
     * Reads a file fully, terminating the JVM with status 1 if that fails.
     *
     * @param path file to read
     * @return the file content
     */
    private static byte[] readAllByteOrExit(Path path){
        try{
            return Files.readAllBytes(path);
        }catch (IOException e){
            System.out.println("Failed to read[" + path + "]:" + e.getMessage());
            System.exit(1);
        }
        return null;
    }

    /**
     * Encodes each text as {@code [CLS]}, its known characters, {@code [SEP]}.
     *
     * <p>Row {@code i} of both arrays is first reset to zero and then filled from the left; the
     * mask is 1 for every written id. Characters missing from the vocabulary are skipped. When
     * a text has more characters than the row can hold, the surplus characters are dropped so
     * that {@code [SEP]} still fits.
     * NOTE(review): unused slots hold id 0, as they did with freshly allocated arrays — confirm
     * that id 0 is the padding token of this vocabulary.
     *
     * @param inputIds  receives the token ids, one row per text
     * @param inputMask receives the attention mask, one row per text
     * @param text      texts to encode
     * @throws IllegalStateException    if the vocabulary lacks {@code [CLS]} or {@code [SEP]}
     * @throws IllegalArgumentException if a row is too short to hold both markers
     */
    public static void getTextToId(int[][] inputIds, int[][] inputMask, String[] text){
        Integer clsId = word2id.get("[CLS]");
        Integer sepId = word2id.get("[SEP]");
        if (clsId == null || sepId == null) {
            throw new IllegalStateException("Vocabulary must contain [CLS] and [SEP]");
        }

        for (int i = 0; i < text.length; i++) {
            int[] ids = inputIds[i];
            int[] mask = inputMask[i];
            int capacity = Math.min(ids.length, mask.length);
            if (capacity < 2) {
                throw new IllegalArgumentException(
                        "Row " + i + " needs at least 2 slots but has " + capacity);
            }
            Arrays.fill(ids, 0);
            Arrays.fill(mask, 0);

            // Locale.ROOT keeps lower-casing independent of the machine's default locale.
            char[] chs = text[i].trim().toLowerCase(Locale.ROOT).toCharArray();
            int next = 0;
            ids[next] = clsId;
            mask[next] = 1;
            next++;
            // Stop one slot early: the last free slot is reserved for [SEP].
            for (int j = 0; j < chs.length && next < capacity - 1; j++) {
                Integer id = word2id.get(Character.toString(chs[j]));
                if (id != null) {
                    ids[next] = id;
                    mask[next] = 1;
                    next++;
                }
            }
            ids[next] = sepId;
            mask[next] = 1;
        }
    }

    /**
     * Tags the built-in sample queries and prints the tag ids and score of each.
     *
     * <p>Both outputs are fetched in a single graph run. The graph, the session and all tensors
     * are released before the method returns, also when it fails.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        String[] query = new String[]{"中华人民共和国", "新疆大学"};
        String resourceDir = "src/main/resources";
        String modelName = "model.pb";

        if (query.length == 0) {
            // Nothing to predict, so do not load the model at all.
            return;
        }
        System.out.println("Ok! Start predicting...\n");

        int batchSize = query.length;
        int padLength = 25;
        int[][] indexes = new int[batchSize][padLength];
        int[][] mask = new int[batchSize][padLength];
        getTextToId(indexes, mask, query);

        byte[] graphDef = readAllByteOrExit(Paths.get(resourceDir, modelName));
        try (Graph g = new Graph()) {
            g.importGraphDef(graphDef);
            try (Session sess = new Session(g);
                 Tensor tensorInputIds = Tensor.create(indexes);
                 Tensor tensorMask = Tensor.create(mask)) {
                // Results come back in the order of the fetch() calls.
                List<?> outputs = sess.runner()
                        .feed("input_ids:0", tensorInputIds)
                        .feed("input_mask:0", tensorMask)
                        .fetch("viterbi/ReverseSequence_1:0")
                        .fetch("viterbi/Max:0")
                        .run();
                try (Tensor tensorSeqResult = (Tensor) outputs.get(0);
                     Tensor tensorScoreResult = (Tensor) outputs.get(1)) {
                    int[][] sequenceId =
                            (int[][]) tensorSeqResult.copyTo(new int[batchSize][padLength]);
                    float[] sequenceScore =
                            (float[]) tensorScoreResult.copyTo(new float[batchSize]);
                    for (int i = 0; i < sequenceId.length; i++) {
                        System.out.println("query: "+query[i]);
                        System.out.println("sequence result: "+ Arrays.toString(sequenceId[i]));
                        System.out.println("sequence score: "+ sequenceScore[i]+'\n');
                    }
                }
            }
        }
    }
}