# TensorFlow 数据预处理

import  tensorflow as tf
from tensorflow import keras
def preprocess(x, y, depth=10):
    """Scale an image to [0, 1] and one-hot encode its label.

    Written to be used as a ``tf.data.Dataset.map`` function over
    ``(image, label)`` pairs.

    Args:
        x: Image tensor; cast to float32 and divided by 255, so pixel values
            are assumed to lie in [0, 255].
        y: Integer class label(s); cast to int64 before encoding.
        depth: Number of classes in the one-hot encoding (default 10, which
            matches the original hard-coded value).

    Returns:
        Tuple ``(x, y)``: the rescaled float32 image and the one-hot encoded
        label, which gains a trailing axis of size ``depth``.
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int64)
    y = tf.one_hot(y, depth=depth)
    # Under Dataset.map this is a plain Python print: it runs while the
    # function is traced and reports the static shape, not once per element.
    print('y shape :', y.shape)
    return x, y
(x, y), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
db = tf.data.Dataset.from_tensor_slices((x, y))
# Shuffle the raw samples *before* preprocessing. The shuffle buffer then holds
# the loaded (uint8, per the Keras docs) images and scalar labels instead of
# float32 images and one-hot float32 labels, which is roughly 4x less memory for
# the images. Sizing the buffer with len(x) keeps it a full shuffle of however
# many samples were actually loaded, instead of a hard-coded 60000.
db2 = db.shuffle(len(x)).map(preprocess).batch(100)
res = next(iter(db2))
print('res[0] shape', res[0].shape)
print('res[1] shape', res[1].shape)