TensorFlow的变量管理:变量作用域机制

在深度学习中,你可能需要用到大量的变量集,而且这些变量集可能在多处都要用到。例如,训练模型时,训练参数如权重(weights)、偏置(biases)等已经定下来,要拿到验证集去验证,我们自然希望这些参数是同一组。以往写简单的程序,可能使用全局限量就可以了,但在深度学习中,这显然是不行的,一方面不便管理,另外这样一来代码的封装性受到极大影响。因此,TensorFlow提供了一种变量管理方法:变量作用域机制,以此解决上面出现的问题。

TensorFlow的变量作用域机制依赖于以下两个方法,官方文档中定义如下:

[plain]view plaincopy

  1. tf.get_variable(name, shape, initializer): Creates or returns a variable with a given name.建立或返回一个给定名称的变量
  2. tf.variable_scope( scope_name): Manages namespaces for names passed to tf.get_variable(). 管理传递给tf.get_variable()的变量名组成的命名空间

先说说tf.get_variable(),这个方法在建立新的变量时与tf.Variable()完全相同。它的特殊之处在于,他还会搜索是否有同名的变量。创建变量用法如下:

[plain]view plaincopy

  1. with tf.variable_scope("foo"):
  2. with tf.variable_scope("bar"):
  3. v = tf.get_variable("v", [1])
  4. assert v.name == "foo/bar/v:0"

而tf.variable_scope(scope_name),它会管理在名为scope_name的域(scope)下传递给tf.get_variable的所有变量名(组成了一个变量空间),根据规则确定这些变量是否进行复用。这个方法最重要的参数是reuse,有None,tf.AUTO_REUSE与True三个选项。具体用法如下:

  1. reuse的默认选项是None,此时会继承父scope的reuse标志。
  2. 自动复用(设置reuse为tf.AUTO_REUSE),如果变量存在则复用,不存在则创建。这是最安全的用法,在使用新推出的EagerMode时reuse将被强制为tf.AUTO_REUSE选项。用法如下:

    [plain]view plaincopy

    1. def foo():
    2. with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    3. v = tf.get_variable("v", [1])
    4. return v
    5. v1 = foo() # Creates v.
    6. v2 = foo() # Gets the same, existing v.
    7. assert v1 == v2
  3. 复用(设置reuse=True):

    [plain]view plaincopy

    1. with tf.variable_scope("foo"):
    2. v = tf.get_variable("v", [1])
    3. with tf.variable_scope("foo", reuse=True):
    4. v1 = tf.get_variable("v", [1])
    5. assert v1 == v
  4. 捕获某一域并设置复用(scope.reuse_variables()):

    [plain]view plaincopy

    1. with tf.variable_scope("foo") as scope:
    2. v = tf.get_variable("v", [1])
    3. scope.reuse_variables()
    4. v1 = tf.get_variable("v", [1])
    5. assert v1 == v

    1)非复用的scope下再次定义已存在的变量;或2)定义了复用但无法找到已定义的变量,TensorFlow都会抛出错误,具体如下:

[plain]view plaincopy

  1. with tf.variable_scope("foo"):
  2. v = tf.get_variable("v", [1])
  3. v1 = tf.get_variable("v", [1])
  4. # Raises ValueError("... v already exists ...").
  5. with tf.variable_scope("foo", reuse=True):
  6. v = tf.get_variable("v", [1])
  7. # Raises ValueError("... v does not exists ...").

转自: https://blog.csdn.net/zbgjhy88/article/details/78960388