【5】TensorFlow光速入门-图片分类完整代码

本文地址:https://www.cnblogs.com/tujia/p/13862364.html

系列文章:

【0】TensorFlow光速入门-序

【1】TensorFlow光速入门-tensorflow开发基本流程

【2】TensorFlow光速入门-数据预处理(得到数据集)

【3】TensorFlow光速入门-训练及评估

【4】TensorFlow光速入门-保存模型及加载模型并使用

【5】TensorFlow光速入门-图片分类完整代码

【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

【7】TensorFlow光速入门-总结

一、完整代码

import pathlib
import random
import tensorflow as tf
from tensorflow import keras
import numpy as np
import IPython.display as display
import matplotlib.pyplot as plt

# 读取文件夹图片数据
data_path = \'/tf/datasets/wnw\'
all_image_paths = []
all_image_labels = []
label_names = []
data_root = pathlib.Path(data_path)
i = 0
for item in data_root.iterdir():
    label_names.append(item.name)
    for image in item.iterdir():
        all_image_paths.append(str(image))
        all_image_labels.append(i)
    i = i + 1
print(label_names)
print(len(all_image_paths))
print(len(all_image_labels))

# 抽样检查
image_count = len(all_image_paths)
for x in range(5):
    i = random.randint(0, image_count-1);
    image_path = all_image_paths[i]
    display.display(display.Image(image_path, width=100, height=100))
    print(label_names[all_image_labels[i]])

# 图片 转 tensor3D 格式
def load_and_preprocess_image(path):
    # 文件 转 tensor
    image = tf.io.read_file(path)
    # 普通 tensor 转 图片tensor,channels 为颜色通道,1表示灰图
    image = tf.image.decode_jpeg(image, channels=1)
    # 缩放图片尺寸为 100*100
    image = tf.image.resize(image, [100, 100])
    # 颜色的数值范围是0-255,所以 image/255,进一步将图片tensor数据数值范围缩到 0-1
    image /= 255
    return image

# 批量处理图片
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# 抽样检查
for i, image in enumerate(image_ds.take(5)):
    plt.imshow(image.numpy().squeeze(), cmap=plt.cm.gray_r)
    plt.grid(False)
    plt.xlabel(label_names[all_image_labels[i]])
    plt.show()

# label 数据集
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))

# 打包图片及其label
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

# 打乱数据
image_count = len(all_image_paths)
ds = image_label_ds.shuffle(buffer_size=image_count)
ds = ds.repeat()
ds = ds.batch(32)
ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
print(ds)

# 模型初始化(配置神经网络层)
model = keras.Sequential([
    # 展平数据,输入类型要和数据集保持一致,我这里是100*100的灰图
    keras.layers.Flatten(input_shape=(100, 100, 1)),
    # 第二层是神经元
    keras.layers.Dense(128, activation=\'relu\'),
    # 第三层的参数很重要,2表示分两类,如果要分5类就传5,10类就传10
    keras.layers.Dense(2, activation=\'softmax\')
])

# 优化器、损失函数及指标
model.compile(optimizer=\'adam\',
              loss=\'sparse_categorical_crossentropy\',
              metrics=[\'accuracy\'])

# 训练 100 次
model.fit(ds, epochs=100, steps_per_epoch=10)

# 评估
test_loss, test_acc = model.evaluate(ds, verbose=2, steps=10)

# 预测
predictions = model.predict(ds, steps=10)
label = np.argmax(predictions[0])
print(label_names[label])

# 保存模型
model.save(\'/tf/saved_model/wnw\')

二、jupyter 笔记本

附件下载: wnw.ipynb

解压缩后,上传 wnw.ipynb 到 tensorflow-tutorials 目录就行了

参考【2】TensorFlow光速入门-数据预处理(得到数据集) 准备好图片数据后,直接运行 wnw.ipynb 就行了

注:图片数据需为jpg格式,不能用png或gif格式的,否则会报错~~

下一节,我们来看一下训练好的模型如果在 web 项目中应用:

【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

本文链接:https://www.cnblogs.com/tujia/p/13862364.html


完。