Tensorflow name_scope

在 Tensorflow 当中有两种途径生成变量 variable, 一种是 tf.get_variable(), 另一种是 tf.Variable().

使用tf.get_variable()定义的变量不会被tf.name_scope()当中的名字所影响

 1 import tensorflow as tf
 2 
 3 with tf.name_scope("a_name_scope"):
 4     initializer = tf.constant_initializer(value=1)
 5     var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32, initializer=initializer)
 6     var2 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
 7     var21 = tf.Variable(name='var2', initial_value=[2.1], dtype=tf.float32)
 8     var22 = tf.Variable(name='var2', initial_value=[2.2], dtype=tf.float32)
 9 
10 
11 with tf.Session() as sess:
12     sess.run(tf.initialize_all_variables())
13     print(var1.name)        # var1:0
14     print(sess.run(var1))   # [ 1.]
15     print(var2.name)        # a_name_scope/var2:0
16     print(sess.run(var2))   # [ 2.]
17     print(var21.name)       # a_name_scope/var2_1:0
18     print(sess.run(var21))  # [ 2.0999999]
19     print(var22.name)       # a_name_scope/var2_2:0
20     print(sess.run(var22))  # [ 2.20000005]

想要达到重复利用变量的效果, 我们就要使用 tf.variable_scope(), 并搭配 tf.get_variable()这种方式产生和提取变量. 不像 tf.Variable() 每次都会产生新的变量, tf.get_variable() 如果遇到了同样名字的变量时, 它会单纯的提取这个同样名字的变量(避免产生新变量). 而在重复使用的时候, 一定要在代码中强调 scope.reuse_variables(), 否则系统将会报错, 以为你只是单纯的不小心重复使用到了一个变量.

 1 with tf.variable_scope("a_variable_scope") as scope:
 2     initializer = tf.constant_initializer(value=3)
 3     var3 = tf.get_variable(name='var3', shape=[1], dtype=tf.float32, initializer=initializer)
 4     scope.reuse_variables()
 5     var3_reuse = tf.get_variable(name='var3',)
 6     var4 = tf.Variable(name='var4', initial_value=[4], dtype=tf.float32)
 7     var4_reuse = tf.Variable(name='var4', initial_value=[4], dtype=tf.float32)
 8     
 9 with tf.Session() as sess:
10     sess.run(tf.global_variables_initializer())
11     print(var3.name)            # a_variable_scope/var3:0
12     print(sess.run(var3))       # [ 3.]
13     print(var3_reuse.name)      # a_variable_scope/var3:0
14     print(sess.run(var3_reuse)) # [ 3.]
15     print(var4.name)            # a_variable_scope/var4:0
16     print(sess.run(var4))       # [ 4.]
17     print(var4_reuse.name)      # a_variable_scope/var4_1:0
18     print(sess.run(var4_reuse)) # [ 4.]

1 with tf.variable_scope('foo') as foo_scope:
2     v = tf.get_variable('v', [1])
3 with tf.variable_scope('foo', reuse=True):
4     v1 = tf.get_variable('v')
5 assert v1 == v

1. 使用tf.Variable()的时候,tf.name_scope()tf.variable_scope() 都会给 Variableopname属性加上前缀。

2. 使用tf.get_variable()的时候,tf.name_scope()就不会给 tf.get_variable()创建出来的Variable加前缀。