使用 TensorFlow 构建 CNN 网络进行中文手写文字识别

数据集下载地址:http://www.nlpr.ia.ac.cn/databases/handwriting/download.html

chinese_write_detection.py

# -*- coding: utf-8 -*-
import tensorflow as tf
import os
import random
import tensorflow.contrib.slim as slim
import time
import numpy as np
import pickle
from PIL import Image
from log_utils import get_logger

logger = get_logger("HandWritten  Practice")
# Project root; the default checkpoint/data/log locations below are derived from it.
root_path = 'D:/eclipse-workspace/sxzsb'

# --- data augmentation (only applied when input_pipeline is called with aug=True) ---
tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to random flip up down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random contrast")

# --- model and training schedule ---
tf.app.flags.DEFINE_integer('charset_size', 3755, "Choose the first `charset_size` character to conduct our experiment.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rgb to gray")
tf.app.flags.DEFINE_integer('max_steps', 12002, 'the max training steps ')
tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', 2000, "the steps to save")

# --- locations (same default values as before, derived from root_path) ---
# NOTE(review): 'train_data_dir' and 'test_data_dir' (like 'gray' and 'epoch')
# are not read by the code in this file; train()/validation() use '../data/...'.
tf.app.flags.DEFINE_string('checkpoint_dir', root_path + '/checkpoint', 'the checkpoint dir')
tf.app.flags.DEFINE_string('train_data_dir', root_path + '/data/train', 'the train dataset dir(containing png files)')
tf.app.flags.DEFINE_string('test_data_dir', root_path + '/data/test', 'the test dataset dir(containing png files)')
tf.app.flags.DEFINE_string('log_dir', root_path + '/log', 'the logging path')

# --- run control ---
tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_integer('epoch', 1, 'Number of epochs')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size for training and validation')
# The help text lists exactly the modes main() dispatches on.
tf.app.flags.DEFINE_string('mode', 'train', 'Running mode. One of {"train", "validation", "inference"}')
FLAGS = tf.app.flags.FLAGS


class DataIterator:
    """Index a directory of labelled images and expose it as a TF batch pipeline.

    Expected layout: ``data_dir/<class_id>/.../<image>.png`` where ``<class_id>``
    is an integer folder name such as ``00020`` (label 20).
    """

    def __init__(self, data_dir, charset_size=None):
        """Collect the image paths and integer labels found under ``data_dir``.

        Args:
            data_dir: dataset root directory; a trailing path separator is optional.
            charset_size: only class folders whose integer id is below this value
                are used. Defaults to ``FLAGS.charset_size`` (set it to a small
                value if available computation power is limited).

        Raises:
            ValueError: if no image file is found for the selected classes.
        """
        if charset_size is None:
            charset_size = FLAGS.charset_size
        base_dir = os.path.normpath(data_dir)
        samples = []
        for root, sub_folders, file_list in os.walk(base_dir):
            relative = os.path.relpath(root, base_dir)
            if relative == os.curdir:
                # Top level: descend only into numeric class folders below
                # charset_size (pruned in place). Files sitting directly in
                # data_dir have no class folder and therefore no label.
                sub_folders[:] = [name for name in sub_folders
                                  if name.isdigit() and int(name) < charset_size]
                continue
            # The first path component below data_dir is the class folder.
            label = int(relative.split(os.sep)[0])  # int("00020") -> 20
            samples.extend((os.path.join(root, file_name), label) for file_name in file_list)
        if not samples:
            raise ValueError('no images found under {0!r} for charset_size={1}'
                             .format(data_dir, charset_size))
        random.shuffle(samples)
        self.image_names = [path for path, _ in samples]
        self.labels = [label for _, label in samples]

    @property
    def size(self):
        """Number of indexed samples (read-only)."""
        return len(self.labels)

    @staticmethod
    def data_augmentation(images):
        """Apply the random augmentations enabled by the FLAGS switches to ``images``."""
        if FLAGS.random_flip_up_down:
            images = tf.image.random_flip_up_down(images)
        if FLAGS.random_brightness:
            images = tf.image.random_brightness(images, max_delta=0.3)
        if FLAGS.random_contrast:
            images = tf.image.random_contrast(images, 0.8, 1.2)
        return images

    def input_pipeline(self, batch_size, num_epochs=None, aug=False):
        """Build a queue-based pipeline that yields shuffled (image, label) batches.

        Args:
            batch_size: number of samples per batch.
            num_epochs: forwarded to ``tf.train.slice_input_producer``
                (None means no epoch limit).
            aug: whether to apply ``data_augmentation`` to each image.

        Returns:
            (image_batch, label_batch): float32 grayscale images resized to
            ``FLAGS.image_size`` x ``FLAGS.image_size`` with one channel, and
            int64 labels.
        """
        # 1. Build the input queue from the file names and labels. Each
        #    convert_to_tensor call adds a constant op holding the Python data.
        images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
        labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
        input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)
        # 2. Dequeue one (file name, label) pair and decode the PNG as grayscale floats.
        labels = input_queue[1]
        images_content = tf.read_file(input_queue[0])
        images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
        if aug:
            images = self.data_augmentation(images)
        new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
        images = tf.image.resize_images(images, new_size)
        # 3. Collect shuffled batches from the queue.
        image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
                                                          min_after_dequeue=10000)
        return image_batch, label_batch


def build_graph(top_k):
    """Build the CNN classifier graph and return its placeholders, ops and tensors.

    Architecture: three conv(3x3, SAME) + max-pool(2x2) stages with 64/128/256
    filters, dropout, a 1024-unit tanh fully connected layer, dropout, and a
    linear output layer with ``FLAGS.charset_size`` logits.

    Args:
        top_k: how many top predictions ``predicted_*_top_k`` and
            ``accuracy_top_k`` consider.

    Returns:
        dict of the input placeholders ('images', 'labels', 'keep_prob'), the
        training ops and the evaluation/prediction tensors.
    """
    keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
    # Use the configured size so the placeholder matches the batches produced by
    # DataIterator.input_pipeline and inference(), which both use FLAGS.image_size.
    images = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.image_size, FLAGS.image_size, 1],
                            name='image_batch')
    labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')

    # slim.conv2d(inputs, num_outputs, kernel_size [height, width], stride, padding)
    conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
    max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
    conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
    max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
    conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
    max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')

    flatten = slim.flatten(max_pool_3)
    fc1 = slim.fully_connected(tf.nn.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
    logits = slim.fully_connected(tf.nn.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2')
    # The output softmax and the cross-entropy are computed together in one op
    # on the raw logits.
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

    # NOTE: the step counter is a float32 variable named "step"; it is kept as-is
    # so that existing checkpoints still restore.
    global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
    # minimize() also increments global_step once per training step.
    train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step)
    # Explicit softmax probabilities, used only for prediction outputs.
    probabilities = tf.nn.softmax(logits)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary_op = tf.summary.merge_all()
    predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
    # Fraction of samples whose true label is among the top_k predictions.
    accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))

    return {'images': images,
            'labels': labels,
            'keep_prob': keep_prob,
            'top_k': top_k,
            'global_step': global_step,
            'train_op': train_op,
            'loss': loss,
            'accuracy': accuracy,
            'accuracy_top_k': accuracy_in_top_k,
            'merged_summary_op': merged_summary_op,
            'predicted_distribution': probabilities,
            'predicted_index_top_k': predicted_index_top_k,
            'predicted_val_top_k': predicted_val_top_k}


def train():
    """Train the model with periodic test-batch evaluation and checkpointing.

    Training stops once the global step exceeds ``FLAGS.max_steps`` (or the
    input pipeline is exhausted). A final checkpoint is written in both cases,
    so the steps after the last periodic save are not lost.
    """
    print('Begin training')
    # NOTE(review): data locations are hard-coded relative paths; the
    # --train_data_dir / --test_data_dir flags are not used here.
    train_feeder = DataIterator(data_dir='../data/train/')
    test_feeder = DataIterator(data_dir='../data/test/')
    # tf.train.Saver.save refuses to write into a missing directory.
    if not os.path.isdir(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'my-model')
    with tf.Session() as sess:
        # The input pipelines must exist before the queue runners are started.
        train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
        graph = build_graph(top_k=1)
        sess.run(tf.global_variables_initializer())
        # Start the queue-runner threads that feed the pipelines.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()

        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
        if FLAGS.restore:
            # Resume from the latest checkpoint; the global step is restored with it.
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))

        logger.info(':::Training Start:::')
        try:
            while not coord.should_stop():
                start_time = time.time()
                train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                feed_dict = {graph['images']: train_images_batch,
                             graph['labels']: train_labels_batch,
                             graph['keep_prob']: 0.8}  # keep 80% of the connections
                _, loss_val, train_summary, step = sess.run(
                    [graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
                    feed_dict=feed_dict)
                train_writer.add_summary(train_summary, step)
                end_time = time.time()
                logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                if step > FLAGS.max_steps:
                    break
                if step % FLAGS.eval_steps == 1:
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    accuracy_test, test_summary = sess.run(
                        [graph['accuracy'], graph['merged_summary_op']],
                        feed_dict=feed_dict)
                    test_writer.add_summary(test_summary, step)
                    logger.info('===============Eval a batch=======================')
                    logger.info('the step {0} test accuracy: {1}'
                                .format(step, accuracy_test))
                    logger.info('===============Eval a batch=======================')
                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    saver.save(sess, checkpoint_path, global_step=graph['global_step'])
        except tf.errors.OutOfRangeError:
            # Expected end of data when the input pipeline has a finite num_epochs.
            pass
        finally:
            # Once any thread requests a stop, coord.should_stop() returns True everywhere.
            coord.request_stop()
        logger.info('==================Train Finished================')
        # Save the final weights whether training ended at max_steps or at end of data.
        saver.save(sess, checkpoint_path, global_step=graph['global_step'])
        coord.join(threads)
        train_writer.close()
        test_writer.close()


def validation():
    """Evaluate the latest checkpoint on one epoch of the test set.

    Accuracy is averaged over the samples actually evaluated. shuffle_batch is
    called without allow_smaller_final_batch, so a trailing partial batch is
    not evaluated and must not count in the denominator.

    Returns:
        dict with the per-sample top-3 probabilities ('prob'), top-3 class
        indices ('indices') and true labels ('groundtruth').
    """
    print('validation')
    test_feeder = DataIterator(data_dir='../data/test/')

    final_predict_val = []
    final_predict_index = []
    groundtruth = []

    with tf.Session() as sess:
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)
        graph = build_graph(3)

        sess.run(tf.global_variables_initializer())
        # The epoch-limited input producer keeps its state in local variables.
        sess.run(tf.local_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt is not None:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))
        else:
            logger.warning('no checkpoint found in {0}; evaluating untrained weights'
                           .format(FLAGS.checkpoint_dir))

        logger.info(':::Start validation:::')
        num_batches = 0
        num_samples = 0
        correct_top_1, correct_top_k = 0.0, 0.0
        try:
            while not coord.should_stop():
                start_time = time.time()
                test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                feed_dict = {graph['images']: test_images_batch,
                             graph['labels']: test_labels_batch,
                             graph['keep_prob']: 1.0}
                probs, indices, acc_1, acc_k = sess.run([graph['predicted_val_top_k'],
                                                         graph['predicted_index_top_k'],
                                                         graph['accuracy'],
                                                         graph['accuracy_top_k']], feed_dict=feed_dict)
                # Count the batch only after it was fully evaluated.
                num_batches += 1
                batch_len = len(test_labels_batch)
                num_samples += batch_len
                final_predict_val += probs.tolist()
                final_predict_index += indices.tolist()
                groundtruth += test_labels_batch.tolist()
                # Batch accuracies are means; weight them by the batch size.
                correct_top_1 += acc_1 * batch_len
                correct_top_k += acc_k * batch_len
                end_time = time.time()
                logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
                            .format(num_batches, end_time - start_time, acc_1, acc_k))
        except tf.errors.OutOfRangeError:
            logger.info('==================Validation Finished================')
            if num_samples:
                logger.info('top 1 accuracy {0} top k accuracy {1}'
                            .format(correct_top_1 / num_samples, correct_top_k / num_samples))
            else:
                logger.warning('no complete batch was evaluated; the test set may be smaller than batch_size')
        finally:
            coord.request_stop()
        coord.join(threads)
    return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}


def inference(image):
    """Predict the top-3 classes of a single image file with the latest checkpoint.

    Args:
        image: path (or file object) of the image, as accepted by ``Image.open``.

    Returns:
        (predict_val, predict_index): top-3 softmax probabilities and the
        matching class indices for the single input image.

    Raises:
        ValueError: if no checkpoint is found in ``FLAGS.checkpoint_dir``.
    """
    print('inference')
    # Grayscale, resized to image_size x image_size, scaled to [0, 1]. This
    # matches the value range produced by convert_image_dtype in training.
    with Image.open(image) as source:
        temp_image = source.convert('L').resize((FLAGS.image_size, FLAGS.image_size), Image.LANCZOS)
    temp_image = np.asarray(temp_image, dtype=np.float32) / 255.0
    temp_image = temp_image.reshape([-1, FLAGS.image_size, FLAGS.image_size, 1])
    with tf.Session() as sess:
        logger.info('========start inference============')
        graph = build_graph(top_k=3)
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if not ckpt:
            # Without restored weights the variables are uninitialized.
            raise ValueError('no checkpoint found in {0}'.format(FLAGS.checkpoint_dir))
        saver.restore(sess, ckpt)
        # Labels are not fed: only the prediction tensors are evaluated.
        predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
                                              feed_dict={graph['images']: temp_image, graph['keep_prob']: 1.0})
    return predict_val, predict_index


def main(_):
    """Dispatch to train / validation / inference according to ``FLAGS.mode``.

    Raises:
        ValueError: if ``FLAGS.mode`` is not one of the supported modes.
    """
    print(FLAGS.mode)
    if FLAGS.mode == "train":
        train()
    elif FLAGS.mode == 'validation':
        # Per-sample top-k probabilities, indices and ground-truth labels.
        dct = validation()
        result_file = 'result.dict'
        logger.info('Write result into {0}'.format(result_file))
        with open(result_file, 'wb') as f:
            pickle.dump(dct, f)
        logger.info('Write file ends')
    elif FLAGS.mode == 'inference':
        image_path = '../data/test/00159/75700.png'
        # The true class id is the image's parent folder name (00159 -> 159).
        true_label = int(os.path.basename(os.path.dirname(image_path)))
        final_predict_val, final_predict_index = inference(image_path)
        logger.info('the result info label {0} predict index {1} predict_val {2}'.format(true_label,
                                                                                         final_predict_index,
                                                                                         final_predict_val))
    else:
        raise ValueError('unknown mode {0!r}; expected one of "train", "validation", "inference"'
                         .format(FLAGS.mode))


if __name__ == "__main__":
    # tf.app.run parses the command-line flags into FLAGS, then calls main(argv).
    tf.app.run(main=main)

log_utils.py

# -*- coding:utf-8 -*-
import os, os.path as osp
import time


def strftime(t=None):
    """Format a POSIX timestamp as ``YYYYmmdd-HHMMSS`` in local time.

    Args:
        t: seconds since the epoch; ``None`` (the default) means "now".
            A timestamp of ``0`` is formatted as the epoch, not as "now".
    """
    # time.localtime(None) already uses the current time, so no `or` fallback
    # is needed (it wrongly treated t=0 as "now").
    return time.strftime("%Y%m%d-%H%M%S", time.localtime(t))


#################
# Logging
#################
import logging
from logging.handlers import TimedRotatingFileHandler
_LOG_FORMAT = "[ %(asctime)s][%(module)s.%(funcName)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT)

# Level applied by get_logger() when no explicit level is passed.
DEFAULT_LEVEL = logging.INFO
# Directory for timestamped log files; init_fh() skips file logging when None.
# NOTE(review): the "gcforest" sub-directory name looks inherited from another project; confirm.
DEFAULT_LOGGING_DIR = os.path.join("logs", "gcforest")
# Shared file handler, created lazily by init_fh().
fh = None


def init_fh():
    """Create the shared file handler ``fh`` once, if file logging is enabled.

    Does nothing when ``fh`` already exists or ``DEFAULT_LOGGING_DIR`` is None.
    Otherwise it creates the directory if needed and opens a log file named
    after the current local time (``YYYYmmdd-HHMMSS.log``).
    """
    global fh
    if fh is None and DEFAULT_LOGGING_DIR is not None:
        if not osp.exists(DEFAULT_LOGGING_DIR):
            os.makedirs(DEFAULT_LOGGING_DIR)
        log_file = osp.join(DEFAULT_LOGGING_DIR, "{}.log".format(strftime()))
        fh = logging.FileHandler(log_file)
        fh.setFormatter(logging.Formatter("[ %(asctime)s][%(module)s.%(funcName)s] %(message)s"))


def update_default_level(defalut_level):
    """Change the level get_logger() applies when called without ``level``.

    Loggers that were already created keep the level they were given. The
    misspelled parameter name is kept for keyword-argument callers.
    """
    global DEFAULT_LEVEL
    DEFAULT_LEVEL = defalut_level


def update_default_logging_dir(default_logging_dir):
    """Change the directory used for new log files (``None`` disables file logging).

    This only matters before init_fh() has created the shared handler; after
    that, init_fh() returns early and the existing handler keeps its file.
    """
    global DEFAULT_LOGGING_DIR
    DEFAULT_LOGGING_DIR = default_logging_dir


def get_logger(name="HandWrittenPractice", level=None):
    """Return the named logger set to ``level`` with the shared file handler attached.

    Args:
        name: logger name passed to ``logging.getLogger``.
        level: logging level; ``None`` means ``DEFAULT_LEVEL``. An explicit
            ``logging.NOTSET`` (0) is honoured instead of being replaced.

    Returns:
        The ``logging.Logger`` for ``name``.
    """
    if level is None:
        level = DEFAULT_LEVEL
    logger = logging.getLogger(name)
    logger.setLevel(level)
    init_fh()
    if fh is not None:
        # addHandler ignores a handler that is already attached, so repeated
        # calls do not duplicate output.
        logger.addHandler(fh)
    return logger

Train

python chinese_write_detection.py --mode=train --max_steps=200000 --eval_steps=1000 --save_steps=10000

Validation

python chinese_write_detection.py --mode=validation

Inference

python chinese_write_detection.py --mode=inference