tensorflow2.0中保存、加载、克隆模型
1. 在磁盘中保存与加载模型
1.1 保存与加载整个模型
保存整个模型:
- 模型的架构/配置
- 模型的权重值(在训练过程中学习)
- 模型的编译信息(如果调用了
compile()
) - 优化器及其状态(如果有的话,使您可以从上次中断的位置重新开始训练)
保存模型
model.save(filepath)
或者
tf.keras.models.save_model(model, filepath)
注意:filepath
的文件格式,如果不加后缀,默认是SavedModel
格式,如果加后缀.h5
,则是HDF5
格式。后者相比前者更加轻量化,但包含内容不如前者。
加载模型
tf.keras.models.load_model(filepath)
注意:如果加载的是h5
格式文件,那么可能会报错:AttributeError: ‘str’ object has no attribute \'decode
。这是由于h5py
版本过高导致,可以安装只能版本的h5py
,即pip install h5py==2.10.0
。
举例
import tensorflow as tf
from tensorflow import keras
def get_model():
model = keras.Sequential()
model.add(keras.Input(shape=(1,)))
model.add(keras.layers.Dense(10, keras.activations.relu))
model.add(keras.layers.Dense(1))
model.compile(optimizer=\'sgd\', loss=\'mse\')
return model
model_1 = get_model()
model_1.save("my_model.h5")
# 或者 model_1.save("my_model")
model_2 = tf.keras.models.load_model("my_model.h5")
# 或者 model_2 = tf.keras.models.load_model("my_model")
1.2 只保存与加载参数
保存参数
model.save_weights(filepath)
注意:filepath
的文件格式,如果不加后缀,默认是TensorFlow Checkpoint
格式,如果加后缀.h5
,则是HDF5
格式。具体差别可看官方文档。当网络存在嵌套时,后者可能会有问题。
加载参数
model.load_weights(filepath)
举例
import tensorflow as tf
from tensorflow import keras
def get_model():
model = keras.Sequential()
model.add(keras.Input(shape=(1,)))
model.add(keras.layers.Dense(10, keras.activations.relu))
model.add(keras.layers.Dense(1))
model.compile(optimizer=\'sgd\', loss=\'mse\')
return model
model_1 = get_model()
model_1.save_weights("my_model_weights.h5")
# 或者 model_1.save_weights("my_model_weights")
model_1.load_weights("my_model_weights.h5")
# 或者 model_1.load_weights("my_model_weights.h5")
使用回调函数
使用回调函数同样也可以保存和加载模型参数
在训练时加入ModelCheckpoint
回调函数:
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=文件路径, # 文件存储理解
save_weights_only=True/False, # 是否只保留参数
save_best_only=True/False # 是否只保留最优结果
)
model.fit(
...
callbacks=[cp_callback]
)
# 加载模型参数
model.load_weights(文件路径)
2. 在内存中克隆模型
2.1 克隆整个模型
keras.models.clone_model(model)
注意:这里的model
只能是functional model
或 sequential model
,不能是subclass model
2.2 只克隆参数
获取摸个模型的参数
model.get_weights()
给某个模型的参数赋值
model.set_weights(weights)