III. Saving and Loading TensorFlow Models
1. Saving the model:
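import os
import tensorflow as tf  # TensorFlow 1.x API; under TF 2.x use tf.compat.v1 and call tf.compat.v1.disable_eager_execution() first

# Explicit names become the keys stored in the checkpoint (needed for the name mapping in step 3)
v1 = tf.Variable(1.0, dtype=tf.float32, name='v1')
v2 = tf.Variable(2.0, dtype=tf.float32, name='v2')
x = v1 + v2  # this op is auto-named 'add' in the graph

saver = tf.train.Saver()
os.makedirs('./model', exist_ok=True)  # Saver.save may fail if the parent directory does not exist
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(x)
    # Save the model in the model folder (writes the .meta, .index, .data-* and checkpoint files)
    saver.save(sess, './model/test.model')
    print('result:{}'.format(result))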
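Name-based loading depends on the exact keys the checkpoint holds, so it helps to check them. The following is a minimal sketch, assuming TF 2.x with the 1.x graph API accessed through `tf.compat.v1`. It saves two variables named `v1` and `v2` into a temporary directory and then calls `tf.train.list_variables` to print the stored `(name, shape)` pairs. The helper `save_and_list` is hypothetical and exists only for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TF 2.x, with the 1.x-style graph API used through tf.compat.v1
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def save_and_list(ckpt_dir):
    """Hypothetical helper: save two named variables, then list the checkpoint keys."""
    graph = tf.Graph()
    with graph.as_default():
        tf1.Variable(1.0, dtype=tf.float32, name='v1')
        tf1.Variable(2.0, dtype=tf.float32, name='v2')
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            # save() returns the checkpoint prefix it actually wrote
            prefix = saver.save(sess, os.path.join(ckpt_dir, 'test.model'))
    # list_variables returns (name, shape) pairs, one per saved variable
    return tf.train.list_variables(prefix)


with tempfile.TemporaryDirectory() as tmp:
    stored = save_and_list(tmp)
    print(stored)
```

If the variables were created without a `name=` argument, these keys would be TensorFlow's default names rather than `v1`/`v2`, and a dictionary passed to `Saver` would have to use those defaults instead.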
2. Loading the model (importing the saved graph directly):
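import tensorflow as tf

# Rebuild the graph structure from the .meta file; the model-building code is not needed
saver = tf.train.import_meta_graph('./model/test.model.meta')
with tf.Session() as sess:
    # Restore the saved variable values into the imported graph
    saver.restore(sess, './model/test.model')
    # 'add:0' is the first output of the auto-named addition op from step 1
    print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))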
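Hard-coding `'add:0'` only works if you already know the op name. One way to find it is to list the operations in the imported graph. This is a sketch under stated assumptions: it uses TF 2.x with `tf.compat.v1`, rebuilds the step-1 graph in a temporary directory, and filters operations by type. The type is `Add` or `AddV2` depending on the TensorFlow version. The helpers `save_model` and `load_and_inspect` are hypothetical names.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TF 2.x, with the 1.x-style graph API used through tf.compat.v1
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def save_model(prefix):
    """Hypothetical helper: build and save the step-1 graph, return the prefix."""
    with tf.Graph().as_default() as graph:
        v1 = tf1.Variable(1.0, dtype=tf.float32, name='v1')
        v2 = tf1.Variable(2.0, dtype=tf.float32, name='v2')
        v1 + v2  # creates the auto-named 'add' op in the graph
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            return saver.save(sess, prefix)


def load_and_inspect(prefix):
    """Hypothetical helper: import the .meta graph, list addition ops, evaluate 'add:0'."""
    with tf.Graph().as_default() as graph:
        saver = tf1.train.import_meta_graph(prefix + '.meta')
        with tf1.Session(graph=graph) as sess:
            saver.restore(sess, prefix)
            # The op type for '+' is 'Add' or 'AddV2' depending on the TF version
            add_ops = [op.name for op in graph.get_operations()
                       if op.type in ('Add', 'AddV2')]
            value = sess.run(graph.get_tensor_by_name('add:0'))
    return add_ops, value


with tempfile.TemporaryDirectory() as tmp:
    ops, result = load_and_inspect(save_model(os.path.join(tmp, 'test.model')))
    print(ops, result)
```

A more robust design is to name the output explicitly when building the model, for example `tf.add(v1, v2, name='x')`, so the loader can fetch `'x:0'` without relying on auto-generated names.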
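3. Loading the model with a name mapping (mainly used when one program loads a model saved by another, and its variables have different names):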
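import tensorflow as tf

# The variables here are named 'a' and 'b', unlike the checkpoint keys 'v1' and 'v2'
a = tf.Variable(5.0, dtype=tf.float32, name='a')
b = tf.Variable(6.0, dtype=tf.float32, name='b')
x = a + b

# Map checkpoint names to this program's variables: 'v1' -> a, 'v2' -> b
saver = tf.train.Saver({'v1': a, 'v2': b})
with tf.Session() as sess:
    # restore() assigns the saved values, so no initializer is needed here
    saver.restore(sess, './model/test.model')
    print(sess.run([x]))  # computed from the restored values, not from 5.0 and 6.0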
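A different technique can reach the same goal without a `Saver` dictionary. You read the stored values by checkpoint key with `tf.train.load_checkpoint` and use them as the initial values of the new variables. This is a minimal sketch, assuming TF 2.x with `tf.compat.v1` and a checkpoint whose keys are `v1` and `v2`. The helpers `save_model` and `load_by_key` are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TF 2.x, with the 1.x-style graph API used through tf.compat.v1
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def save_model(prefix):
    """Hypothetical helper: save variables named 'v1' and 'v2', as in step 1."""
    with tf.Graph().as_default() as graph:
        tf1.Variable(1.0, dtype=tf.float32, name='v1')
        tf1.Variable(2.0, dtype=tf.float32, name='v2')
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            return saver.save(sess, prefix)


def load_by_key(prefix):
    """Hypothetical helper: read values by checkpoint key and initialize 'a' and 'b' with them."""
    reader = tf.train.load_checkpoint(prefix)
    with tf.Graph().as_default() as graph:
        # Checkpoint keys ('v1', 'v2') and new variable names ('a', 'b') are independent
        a = tf1.Variable(reader.get_tensor('v1'), name='a')
        b = tf1.Variable(reader.get_tensor('v2'), name='b')
        x = a + b
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            return sess.run(x)


with tempfile.TemporaryDirectory() as tmp:
    total = load_by_key(save_model(os.path.join(tmp, 'test.model')))
    print(total)
```

The `Saver` mapping above keeps restoring inside the session, while this variant bakes the loaded values into the new graph's initializers. The latter can be convenient when only a few tensors from an existing checkpoint are needed.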