读取 Keras 中的 Fashion-MNIST 数据集,并查看训练集中的前几张图片

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras

# Load Fashion-MNIST through Keras and preview the first training images.
fashion_mnist = keras.datasets.fashion_mnist
# load_data() already returns a train/test split; there is no validation split.
(train_X, train_y), (test_X, test_y) = fashion_mnist.load_data()

# Carve the validation set off the front of the training data.
n_valid = 1000
valid_X, train_X = train_X[:n_valid], train_X[n_valid:]
valid_y, train_y = train_y[:n_valid], train_y[n_valid:]

# Display names indexed by the integer label stored in train_y.
class_name = ['T-shirt', 'Trouser', 'pullover', 'dress', 'coat',
              'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

# Show a row x col grid of the first row*col training samples.
row = 3
col = 3
plt.figure()
for r in range(row):
    for c in range(col):
        index = col * r + c                 # 0-based sample index
        plt.subplot(row, col, index + 1)    # subplot positions are 1-based
        plt.imshow(train_X[index], cmap='binary')
        plt.axis("off")
        plt.title(class_name[train_y[index]])
plt.show()

load_data() 会直接返回已经划分好的训练集和测试集,但不提供验证集,需要自己从训练集中划分(这里取训练集的前 1000 个样本作为验证集)。

注意 plt.subplot(nrows, ncols, index) 的用法:子图位置 index 从 1 开始编号(按行从左到右),而数组下标从 0 开始,两者不要混用。