TensorFlow slim,二 使用TF-slim编程模板

  TF-slim 模块是TensorFLow中比较实用的API之一,是一个用于模型构建、训练、评估复杂模型的轻量化库

  最近,在使用TF-slim API编写了一些项目模型后,发现TF-slim模块在搭建网络模型时具有相同的编写模式。这个编写模式主要包含四个部分:

  • __init__():
  • build_model():
  • fit():
  • predict():

1. __init__():

  这部分相当于是一个main()函数,其中包含参数的设置,模型整体的连接等操作。具体来说:

  a. 设置参数

  由于是类的构造函数,所以需要在其中设置一些模型网络结构的参数、模型训练时的参数等等。例如

  • 学习率
  • batch_size
  • 训练代数
  • 各种文件的存放地址
  • ...
  • 对于网络结构复杂的模型,还可以将网络结构的table以列表的形式进行保存。便于后续建立模型时可以循环获取每层的超参数。
1 self.lr = lr
2 self.batch_size = batch_size
3 self.epoch = epoch
4 self.checkpoint_dir_load = checkpoint_dir
5 self.checkpoint_dir = os.path.join(checkpoint_dir, filename + ".ckpt")
6 self.logdir = logdir
7 self.result_dir = result_dir

  b. 设置输入、输出的占位符placeholder

  由于TF-slim框架仍然采用的是tensorflow的那一套,不像tf.keras可以使用keras.layer.Input(),所以还需要使用占位符。例如

1 self.input_image = tf.placeholder(tf.float32, shape=[None, 6000])
2 self.input_image_raw = tf.reshape(self.input_image, shape=[-1, 6000, 1])
3 
4 self.input_image_label = tf.placeholder(tf.float32, shape=[None, 1, 10])
5 self.input_label = tf.reshape(self.input_image_label, shape=[-1, 10])

  c. 初始化网络结构,生成训练输出和测试输出

  用于后续损失的计算以及优化器的生成,以及训练结果和测试结果的调用。

  此处会涉及到网络参数的重用,需要使用tf.variable_scope()来管理参数。

1 with tf.variable_scope("Network_Structure") as scope:
2     self.train_digits = self.build_model(is_trained=True)
3     scope.reuse_variables()
4     self.test_digits = self.build_model(is_trained=False)

  d. 损失函数和优化器的声明

  此处损失声明使用的是 输出的占位符和训练的输出。例如:

1 self.loss = slim.losses.softmax_cross_entropy(logits=self.train_digits, onehot_labels=self.input_label, scope="loss")
2 
3 self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(loss=self.loss)

  e. 最终训练输出结果和测试输出结果的计算

  由于网络输出的结果不一定是最终的结果。对于多分类问题,需要将one_hot编码的结果显示为类值;对于回归问题,输出结果可能会需要反归一化。等等..

  如下述代码,多分类问题的one_hot转化为类标签,并进行准确率的计算。

 1 # result and accuracy of test
 2 self.predicts = tf.math.argmax(self.test_digits, 1)   # 将one_hot转化为类标签
 3 self.test_correction = tf.equal(self.predicts, tf.math.argmax(self.input_label, 1))
 4 self.accuracy = tf.reduce_mean(tf.cast(self.test_correction, "float"))
 5 tf.summary.scalar("test_accuracy", self.accuracy)
 6 
 7 # result and accuracy of train
 8 self.train_result = tf.math.argmax(self.train_digits, 1)
 9 self.train_correlation = tf.equal(self.train_result, tf.math.argmax(self.input_label, 1))
10 self.train_accuracy = tf.reduce_mean(tf.cast(self.train_correlation, "float"))
11 tf.summary.scalar("train_accuracy", self.accuracy)

2. build_model():【可以是别的名字】

  这部分是为了使用tf-slim搭建网络模型结构。有些模型可能一个函数实现不了,需要多个函数。例如具有共享层的Siamese Network,在共享层后还有其他层。

  这一部分也实现了如同tf.keras搭建的模型"乐高式"堆叠,不需要手动为各层生成权重、偏执等参数。也是代码瘦身的重要环节。

 1 with slim.arg_scope([slim.conv1d], padding="SAME", stride=2, activation_fn=tf.nn.relu,
 2                     weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
 3                     weights_regularizer=slim.l2_regularizer(0.005)
 4                     ):
 5     net = slim.conv1d(self.input_image_raw, num_outputs=16, kernel_size=8, padding="VALID", scope='conv_1')
 6     tf.summary.histogram("conv_1", net)
 7     net = slim.conv1d(net, num_outputs=16, kernel_size=8, scope='conv_2')
 8     tf.summary.histogram("conv_2", net)
 9     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_3")
10     net = def_max_pool(net)
11     # net = slim.nn.max_pool1d(net, ksize=2, strides=None, padding="VALID", data_format="NWC", name="max_pool_3")
12     tf.summary.histogram("max_pool_3", net)
13     net = slim.conv1d(net, num_outputs=64, kernel_size=4, scope="conv_4")
14     tf.summary.histogram("conv_4", net)
15     net = slim.conv1d(net, num_outputs=64, kernel_size=4, scope="conv_5")
16     tf.summary.histogram("conv_5", net)
17     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_6")
18     net = def_max_pool(net)
19     # net = slim.nn.max_pool1d(net, ksize=2, strides=1, padding="VALID", name="max_pool_6")
20     tf.summary.histogram("max_pool_6", net)
21     net = slim.conv1d(net, num_outputs=256, kernel_size=4, scope="conv_7")
22     tf.summary.histogram("conv_7", net)
23     net = slim.conv1d(net, num_outputs=256, kernel_size=4, scope="conv_8")
24     tf.summary.histogram("conv_8", net)
25     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_9")
26     net = def_max_pool(net)
27     # net = slim.nn.max_pool1d(net, ksize=1, strides=1, padding="VALID", name="max_pool_9")
28     tf.summary.histogram("max_pool_9", net)
29     net = slim.conv1d(net, num_outputs=512, kernel_size=2, stride=1, scope="conv_10")
30     tf.summary.histogram("conv_10", net)
31     net = slim.conv1d(net, num_outputs=512, kernel_size=2, stride=1, scope="conv_11")
32     tf.summary.histogram("conv_11", net)
33     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_12")
34     net = def_max_pool(net)
35     # net = slim.nn.max_pool1d(net, ksize=1, strides=1, padding="VALID", name="max_pool_12")
36     tf.summary.histogram("max_pool_12", net)
37     net = tf.reduce_mean(net, axis=1, name="global_max_pool_13")   # 起全局平均池化的作用
38     tf.summary.histogram("global_max_pool_13", net)
39     net = slim.dropout(net, keep_prob=0.5, scope="dropout")
40     tf.summary.histogram("dropout", net)
41     digits = slim.fully_connected(net, num_outputs=num_class, activation_fn=tf.nn.softmax, scope="fully_connected_14")
42     tf.summary.histogram("fully_connected_14", digits)
43 return digits

3. fit():

  看名字就知道这一部分需要完成的是训练部分的代码

  这一部分需要包含会话的启动、模型保存器的初始化、循环迭代、batch设置、数据集输入、输出数据获取、喂到网络中、保存模型、会话关闭等操作。如下述代码

 1 sess = tf.Session()  # 启动会话
 2 
 3 merge_summary_op = tf.summary.merge_all()
 4 summary_writer = tf.summary.FileWriter(self.logdir, sess.graph)
 5 
 6 saver = tf.train.Saver(max_to_keep=1)  # 生成保存器
 7 sess.run(tf.global_variables_initializer())   # 变量激活
 8 
 9 for step in range(self.epoch):    # 迭代
10     print("Epoch:%d"%step)
11     avg_cost = 0
12     acc = 0
13     total_batch = int(input_x.shape[0]/self.batch_size)   # 划分batch
14     for batch_num in range(total_batch):   # batch迭代
15         # 获取数据
16         batch_xs = input_x[batch_num*self.batch_size:(batch_num+1)*self.batch_size, :]
17         batch_ys = input_y[batch_num*self.batch_size:(batch_num+1)*self.batch_size, :]
18         batch_ys = sess.run(tf.one_hot(batch_ys, depth=10))
19         # 喂到损失 优化器等等
20         _, loss, acc = sess.run([self.optimizer, self.loss, self.train_accuracy],
21                                                         feed_dict={self.input_image: batch_xs,
22                                                          self.input_image_label: batch_ys})
23         avg_cost += loss / total_batch
24         acc += acc /total_batch
25 
26         summary_str = sess.run(merge_summary_op, feed_dict={self.input_image: batch_xs,
27                                                             self.input_image_label: batch_ys})
28         summary_writer.add_summary(summary_str, global_step=step)
29         print("Epoch:%d, batch: %d, avg_cost: %g, accuracy: %g" % (step, batch_num, avg_cost, acc))
30     # 保存模型
31     saver.save(sess, self.checkpoint_dir, global_step=step)
32 sess.close()   # 会话关闭

4. predict():

  从函数名可以知道这一部分是实现预测部分的代码。其相对于训练的过程要更简单。主要包括会话的启动、保存器的生成、权重的导入(模型的恢复)、预测、关闭会话。如下述代码

 1 sess = tf.Session()   # 会话的启动
 2 
 3 saver = tf.train.Saver()  # 保存器的生成
 4 
 5 module_file = tf.train.latest_checkpoint(self.checkpoint_dir_load)
 6 saver.restore(sess, module_file)    # 模型的恢复
 7 
 8 input_y = sess.run(tf.one_hot(input_y, depth=10))  # 获取输出
 9 # 获取预测结果和预测精度
10 predicts, acc_test = sess.run([self.predicts, self.accuracy], feed_dict={self.input_image: input_x,
11 # 关闭会话                                                                            self.input_image_label: input_y})
12 sess.close()
13 # print("test_accuracy: %f" %acc_test)
14 return predicts, acc_test

  上述四步完成后,便可以编写一个main函数来调用这个类,实现需要的功能。.fit()和.predict()主要是在main()函数来调用。