tensorflow 之模型的保存与加载,二

上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练。

在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量。

 1 #!/usr/bin/env python3           
 2 #-*- coding:utf-8 -*-            
 3 ############################     
 4 #File Name: restore.py           
 5 #Brief:                          
 6 #Author: frank                   
 7 #Mail: frank0903@aliyun.com      
 8 #Created Time:2018-06-22 22:34:16
 9 ############################     
10 
11 import tensorflow as tf
12 
13 v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
14 v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
15 print(v1)                                               
16 result = v1 + v2                                        
17 print(result)                                           
18                                                         
19 saver = tf.train.Saver([v1])#只有变量v1会被加载                                
20                                                         
21 with tf.Session() as sess:                              
22     saver.restore(sess, "my_test_model.ckpt")           
23     print(sess.run(result))                             

执行上面的代码,会报错:tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value v2

字典一般方便实现变量重命名,因为在某些场景下,模型中变量的命名和当前需要加载的变量名称并不相同而且有时候对于那些TF自动生成的变量的名称太长不好表示,那么为了不导致加载模型时找不到变量的问题。

在上一篇博文中,保存的两个变量的名称为v1和v2。

 1 import tensorflow as tf
 2 #保存或加载时给变量重命名
 3 a1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
 4 a2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
 5 result = a1 + a2
 6 
 7 #使用字典来重命名变量就可以加载原模型中的相应变量.如下指定了原来名称为v1的变量现在加载到变量a1中,原来名称为v2的变量现在加载到变量a2中
 8 saver = tf.train.Saver({"v1":a1, "v2":a2})
 9 #因为有时候模型保存时的变量名称和加载时的变量名称不一致,为了解决这个问题,TF可以通过字典将模型保存时的变量名和需要加载的变量关联起来.
10 
11 with tf.Session() as sess:
12     saver.restore(sess, "my_test_model.ckpt")
13     print(sess.run(result))                                                                                                          

在使用滑动平均模型时,tf.train.ExponentialMovingAverage对每一个变量会维护一个影子变量(shadow variable),这个影子变量是TF自动生成的,那么为了方便加载使用影子变量,就可以使用变量重命名。

 1 #!/usr/bin/env python3
 2 #-*- coding:utf-8 -*-
 3 ############################
 4 #File Name: saver_ema.py
 5 #Brief:
 6 #Author: frank
 7 #Mail: frank0903@aliyun.com
 8 #Created Time:2018-06-25 21:02:23
 9 ############################
10 import tensorflow as tf
11 
12 v = tf.Variable(0, dtype=tf.float32, name="v")
13 v2 = tf.Variable(0, dtype=tf.float32, name="v2")
14 for variables in tf.global_variables():
15     print(variables.name)
16 #v:0
17 #v2:0
18 
19 
20 #在声明滑动平均模型后,TF会自动生成一个影子变量                                                                                                                 
21 ema = tf.train.ExponentialMovingAverage(0.99)
22 maintain_averages_op = ema.apply(tf.global_variables())
23 for variables in tf.global_variables():
24     print(variables.name)
25 #v:0
26 #v2:0
27 #v/ExponentialMovingAverage:0
28 #v2/ExponentialMovingAverage:0
29 
30 print(ema.variables_to_restore())
31 #{'v2/ExponentialMovingAverage': <tf.Variable 'v2:0' shape=() dtype=float32_ref>, 'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
32 
33 saver = tf.train.Saver()
34 
35 with tf.Session() as sess:
36     init_op = tf.global_variables_initializer()
37     sess.run(init_op)
38     
39     sess.run(tf.assign(v, 10))
40     sess.run(tf.assign(v2, 10))
41     sess.run(maintain_averages_op)
42     
43     saver.save(sess, "moving_average.ckpt")
44     print(sess.run([v, ema.average(v)]))
45 #[10.0, 0.099999905]

滑动平均模型主要作用是为了增加模型的泛化性,可针对所有参数进行优化。

在TF中,每一个变量的滑动平均值是通过影子变量维护的,所以要获得变量的滑动平均值实际上就是获取这个变量的影子变量的值。如果在加载模型时直接将影子变量映射到变量自身,那么在使用训练好的模型时,就不需要再进行相应变量的滑动平均值的计算。

 1 #!/usr/bin/env python3                                  
 2 #-*- coding:utf-8 -*-
 3 ############################
 4 #File Name: restore_ema.py
 5 #Brief:
 6 #Author: frank
 7 #Mail: frank0903@aliyun.com
 8 #Created Time:2018-06-25 21:51:31
 9 ############################
10 
11 import tensorflow as tf
12 
13 v = tf.Variable(0, dtype=tf.float32, name="v")
14 
15 saver = tf.train.Saver({"v/ExponentialMovingAverage":v})#通过变量重命名将原来的变量v的滑动平均值直接赋给变量v
16 
17 with tf.Session() as sess:
18     saver.restore(sess, "moving_average.ckpt")
19     print(sess.run(v))

源码路径:

https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/saver_ema.py

https://github.com/suonikeyinsuxiao/tf_notes/blob/master/save_restore/restore_ema.py