A Brief Introduction to Reading and Writing TFRecord + Demo, based on Ubuntu 18.04 + TensorFlow 1.12, with no WARNINGs

  • TFRecord is the data storage format officially recommended by TensorFlow.
  • It standardizes how data is written and read.
  • Once a TFRecord file has been generated, all subsequent data reading and preprocessing becomes more efficient.

Converting images into a TFRecord file

In this example, the Fashion-MNIST data is converted into TFRecord files. First download the Fashion-MNIST dataset into the current directory; see: https://github.com/zalandoresearch/fashion-mnist/tree/master/data/fashion
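If the four gzipped idx files are not in place yet, a small helper along the following lines can fetch them (a sketch only; the download URL is assumed from the repository linked above and may change):

import os
import urllib.request

# Sketch: download the four gzipped idx files into ./data/fashion/.
# The base URL is assumed from the fashion-mnist repository and may change.
BASE_URL = 'https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/'
FILES = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
         't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']

os.makedirs('./data/fashion/', exist_ok=True)
for name in FILES:
    target = os.path.join('./data/fashion/', name)
    if not os.path.exists(target):
        urllib.request.urlretrieve(BASE_URL + name, target)
        print('downloaded', name)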

import numpy as np
import tensorflow as tf
import gzip
import os

fashion_mnist_directory = './data/fashion/'

def load_mnist(path, kind='train'):
    """Load Fashion-MNIST images and labels from the gzipped idx files under `path`."""
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)

    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8)

    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8, offset=16).reshape(-1, 784)

    print(labels_path, "shape =", labels.shape)
    print(images_path, "shape =", images.shape)

    return images, labels


def make_example(image, label):
    """Wrap one image/label pair in a tf.train.Example proto."""
    return tf.train.Example(features=tf.train.Features(feature={
        'image_raw' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
        'label' :     tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)     ])) }))


def write_tfrecord(images, labels, filename):
    """Serialize all (image, label) pairs into a single TFRecord file."""
    writer = tf.python_io.TFRecordWriter(filename)
    num_examples = labels.shape[0]
    for k, (image, label) in enumerate(zip(images, labels)):
        exam = make_example(image, label)
        writer.write(exam.SerializeToString())
        if k % 100 == 0:
            print("\rwriting", filename, "%6.2f%% completed." % (100.0 * (k + 1) / num_examples), end='')

    print("\rwriting", filename, "%6.2f%% completed." % 100.0)
    writer.close()


def main():
    train_images, train_labels = load_mnist(fashion_mnist_directory, 'train')
    test_images, test_labels   = load_mnist(fashion_mnist_directory, 't10k')
    
    write_tfrecord(train_images, train_labels, 'fashion_mnist_train.tfrecords')
    write_tfrecord(test_images, test_labels, 'fashion_mnist_test.tfrecords')
    
if __name__ == '__main__':
    main()
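To sanity-check the generated file, it can be read back without building a graph. The snippet below (not part of the original demo) parses the first record and prints its label and the size of the raw image bytes:

import tensorflow as tf

# Sketch: inspect the first record of the file written above.
for record in tf.python_io.tf_record_iterator('fashion_mnist_train.tfrecords'):
    example = tf.train.Example()
    example.ParseFromString(record)
    label = example.features.feature['label'].int64_list.value[0]
    raw = example.features.feature['image_raw'].bytes_list.value[0]
    print('label =', label, '| image bytes =', len(raw))  # 784 bytes per 28x28 image
    break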

Reading TFRecord data for training

The following code reads the TFRecord data for training. It is adapted from the official example: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/how_tos/reading_data

The original example raised errors at runtime; they have been fixed here.

Note: in this example, _, loss_value = sess.run([train_op, loss]) consumes exactly one input batch per call, no matter what, or how many, operations are listed inside the brackets.
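The following standalone toy pipeline (a minimal sketch, separate from the training script below) illustrates this behavior:

import tensorflow as tf

# Each sess.run() call advances the input pipeline by exactly one batch,
# regardless of how many fetches are requested in that call.
dataset = tf.data.Dataset.range(10).batch(2)
batch = dataset.make_one_shot_iterator().get_next()
doubled = batch * 2

with tf.Session() as sess:
    print(sess.run([batch, doubled]))  # both fetches use the same batch: [0 1] and [0 2]
    print(sess.run(batch))             # the next call yields the next batch: [2 3]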

import argparse
import os.path
import sys
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import mnist

FLAGS = None

TRAIN_FILE = 'fashion_mnist_train.tfrecords'
VALIDATION_FILE = 'fashion_mnist_test.tfrecords'


def decode(serialized_example):
    """Parse one serialized Example back into a flat uint8 image and an int32 label."""
    features = tf.parse_single_example(serialized_example,
                                       features={'image_raw': tf.FixedLenFeature([], tf.string),
                                                 'label':     tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image.set_shape((mnist.IMAGE_PIXELS))
    label = tf.cast(features['label'], tf.int32)
    return image, label


def augment(image, label):
    """Placeholder for data augmentation."""
    # OPTIONAL: Could reshape into a 28x28 image and apply distortions here.
    return image, label


def normalize(image, label):
    """Convert `image` from [0, 255] -> [-0.5, 0.5] floats."""
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    return image, label


def inputs(train, batch_size, num_epochs):
    """Build a tf.data pipeline over the TFRecord file and return one (images, labels) batch pair."""
    if not num_epochs:
        num_epochs = None
    filename = os.path.join(FLAGS.train_dir, TRAIN_FILE if train else VALIDATION_FILE)

    with tf.name_scope('input'):
        dataset = tf.data.TFRecordDataset(filename)
        dataset = dataset.map(decode)
        dataset = dataset.map(augment)
        dataset = dataset.map(normalize)
        dataset = dataset.shuffle(1000 + 3 * batch_size)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(batch_size)
        iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()


def run_training():
    with tf.Graph().as_default():
        image_batch, label_batch = inputs(train=True,
                                          batch_size=FLAGS.batch_size,
                                          num_epochs=FLAGS.num_epochs)
        logits = mnist.inference(image_batch, FLAGS.hidden1, FLAGS.hidden2)
        loss = mnist.loss(logits, label_batch)
        train_op = mnist.training(loss, FLAGS.learning_rate)
        
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
                           
        with tf.Session() as sess:
            sess.run(init_op)
            try:
                step = 0
                while True:  # Train until OutOfRangeError
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time
                    if step % 100 == 0:
                        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
                    step += 1
            except tf.errors.OutOfRangeError:
                print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))


def main(_):
    run_training()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Initial learning rate.')
    parser.add_argument('--num_epochs',    type=int,   default=2,    help='Number of epochs to run trainer.')
    parser.add_argument('--hidden1',       type=int,   default=128,  help='Number of units in hidden layer 1.')
    parser.add_argument('--hidden2',       type=int,   default=32,   help='Number of units in hidden layer 2.')
    parser.add_argument('--batch_size',    type=int,   default=100,  help='Batch size.')
    parser.add_argument('--train_dir',     type=str,   default='./', help='Directory with the training data.')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

References:

  • https://blog.csdn.net/gg_18826075157/article/details/78449104
  • https://github.com/zalandoresearch/fashion-mnist/blob/master/utils/mnist_reader.py