Tensorflow项目中--FLAGS=tf.flags.FLAGS

  最近看CycleGAN的代码,看到代码里有FLAGS=tf.flags.FLAGS等语句,看不明白,查寻之余发现,这类语句在使用Tensorflow框架的项目里是常见的。并且在看代码解释时,找到一个博主关于这部分只是的梳理,有解释加示例非常清楚,所以就直接应用该作者的文章。


内容包含如下几个我们经常看到的几个函数:

①tf.flags.DEFINE_xxx()

②FLAGS = tf.flags.FLAGS

③FLAGS._parse_flags()

简单的说:

   用于帮助我们添加命令行的可选参数。也就是说可以不用反复修改源代码中的参数,而是利用该函数可以实现在命令行中选择需要设定或者修改的参数来运行程序。

举个栗子:

程序train.py文件中的小部分代码如下所示:

 1 FLAGS = tf.flags.FLAGS
 2 
 3 tf.flags.DEFINE_string(\'name\', \'default\', \'name of the model\')
 4 tf.flags.DEFINE_integer(\'num_seqs\', 100, \'number of seqs in one batch\')
 5 tf.flags.DEFINE_integer(\'num_steps\', 100, \'length of one seq\')
 6 tf.flags.DEFINE_integer(\'lstm_size\', 128, \'size of hidden state of lstm\')
 7 tf.flags.DEFINE_integer(\'num_layers\', 2, \'number of lstm layers\')
 8 tf.flags.DEFINE_boolean(\'use_embedding\', False, \'whether to use embedding\')
 9 tf.flags.DEFINE_integer(\'embedding_size\', 128, \'size of embedding\')
10 tf.flags.DEFINE_float(\'learning_rate\', 0.001, \'learning_rate\')
11 tf.flags.DEFINE_float(\'train_keep_prob\', 0.5, \'dropout rate during training\')
12 tf.flags.DEFINE_string(\'input_file\', \'\', \'utf8 encoded text file\')
13 tf.flags.DEFINE_integer(\'max_steps\', 100000, \'max steps to train\')
14 tf.flags.DEFINE_integer(\'save_every_n\', 1000, \'save the model every n steps\')
15 tf.flags.DEFINE_integer(\'log_every_n\', 10, \'log to the screen every n steps\')
16 tf.flags.DEFINE_integer(\'max_vocab\', 3500, \'max char number\')

在命令行中我们为了执行train.py文件,在命令行中输入:

python train.py \
  --input_file data/shakespeare.txt  \         
  --name shakespeare \
  --num_steps 50 \
  --num_seqs 32 \
  --learning_rate 0.01 \
  --max_steps 20000

通过输入不同的文件名、参数,可以快速完成程序的调参和加载训练集的操作,不需要进入源码中更改。


实践操作

现在我们有如下代码:

 1 import tensorflow as tf
 2 #取上述代码中一部分进行实验  
 3 tf.flags.DEFINE_integer(\'num_seqs\', 100, \'number of seqs in one batch\')   
 4 tf.flags.DEFINE_integer(\'num_steps\', 100, \'length of one seq\')
 5 tf.flags.DEFINE_integer(\'lstm_size\', 128, \'size of hidden state of lstm\')
 6 
 7 #通过print()确定下面内容的功能
 8 FLAGS = tf.flags.FLAGS #FLAGS保存命令行参数的数据
 9 FLAGS._parse_flags() #将其解析成字典存储到FLAGS.__flags中
10 print(FLAGS.__flags)
11 
12 print(FLAGS.num_seqs)
13 
14 print("\nParameters:")
15 for attr, value in sorted(FLAGS.__flags.items()):
16     print("{}={}".format(attr.upper(), value))
17 print("")

按照我现在编写这个博客时间节点来说,第九行的 FLAGS._parse_flags() 在新版本的Tensorflow中不再使用了,如果因为版本造成编译出错,会返回AttributeError: _parse_flags。所以从另一个博主那看到新的的代码为 FLAGS.flag_values_dict() (解析成字典并且存储到FLAGS.__flags中)。

注意点:

  1. 修改参数的方式
  2. 调用参数的方式
  3. 描述参数的方式
  4. 定义参数的类型

原作者:ZQ_ZHU,链接:https://blog.csdn.net/zzq060143/article/details/81952848

修改代码的作者: spring_willow ,链接:https://blog.csdn.net/spring_willow/article/details/80115206

非常感谢作者的分享。