tensorflow2.0中保存、加载、克隆模型

1. 在磁盘中保存与加载模型

1.1 保存与加载整个模型

保存整个模型:

  • 模型的架构/配置
  • 模型的权重值(在训练过程中学习)
  • 模型的编译信息(如果调用了 compile()
  • 优化器及其状态(如果有的话,使您可以从上次中断的位置重新开始训练)

保存模型

model.save(filepath)

或者

tf.keras.models.save_model(model, filepath)

注意:filepath的文件格式,如果不加后缀,默认是SavedModel格式,如果加后缀.h5,则是HDF5格式。后者相比前者更加轻量化,但包含内容不如前者。

加载模型

tf.keras.models.load_model(filepath)

注意:如果加载的是h5格式文件,那么可能会报错:AttributeError: ‘str’ object has no attribute \'decode。这是由于h5py版本过高导致,可以安装只能版本的h5py,即pip install h5py==2.10.0

举例

import tensorflow as tf
from tensorflow import keras

def get_model():
    model = keras.Sequential()
    model.add(keras.Input(shape=(1,)))
    model.add(keras.layers.Dense(10, keras.activations.relu))
    model.add(keras.layers.Dense(1))
    model.compile(optimizer=\'sgd\',  loss=\'mse\')
    return model

model_1 = get_model()

model_1.save("my_model.h5")
# 或者 model_1.save("my_model")

model_2 = tf.keras.models.load_model("my_model.h5")
# 或者 model_2 = tf.keras.models.load_model("my_model")

1.2 只保存与加载参数

保存参数

model.save_weights(filepath)

注意:filepath的文件格式,如果不加后缀,默认是TensorFlow Checkpoint格式,如果加后缀.h5,则是HDF5格式。具体差别可看官方文档。当网络存在嵌套时,后者可能会有问题。

加载参数

model.load_weights(filepath)

举例

import tensorflow as tf
from tensorflow import keras

def get_model():
    model = keras.Sequential()
    model.add(keras.Input(shape=(1,)))
    model.add(keras.layers.Dense(10, keras.activations.relu))
    model.add(keras.layers.Dense(1))
    model.compile(optimizer=\'sgd\',  loss=\'mse\')
    return model

model_1 = get_model()

model_1.save_weights("my_model_weights.h5")
# 或者 model_1.save_weights("my_model_weights")

model_1.load_weights("my_model_weights.h5")
# 或者 model_1.load_weights("my_model_weights.h5")

使用回调函数

使用回调函数同样也可以保存和加载模型参数

在训练时加入ModelCheckpoint回调函数:

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=文件路径,  # 文件存储理解
    save_weights_only=True/False,  # 是否只保留参数
    save_best_only=True/False  # 是否只保留最优结果
)

model.fit(
  ...
  callbacks=[cp_callback]
)

# 加载模型参数
model.load_weights(文件路径)

2. 在内存中克隆模型

2.1 克隆整个模型

keras.models.clone_model(model)

注意:这里的model只能是functional modelsequential model,不能是subclass model

2.2 只克隆参数

获取摸个模型的参数

model.get_weights()

给某个模型的参数赋值

model.set_weights(weights)

参考

https://tensorflow.google.cn/guide/keras/save_and_serialize

https://www.bilibili.com/video/BV1B7411L7Qt?p=22